File size: 5,370 Bytes
9aa5185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Tests for utils.atomic_json_write — crash-safe JSON file writes."""

import json
import os
from pathlib import Path
from unittest.mock import patch

import pytest

from utils import atomic_json_write


class TestAtomicJsonWrite:
    """Core atomic write behavior of ``utils.atomic_json_write``.

    Leftover-file checks compare the complete directory listing against the
    files that are expected to exist, rather than searching names for a
    ``.tmp`` substring. That keeps them independent of how the implementation
    names its temporary files (``tempfile`` defaults, for example, produce
    names such as ``tmpab12cd`` that contain no ``.tmp``), so a leaked temp
    file cannot slip through unnoticed.

    All file reads/writes in these tests use explicit UTF-8 so outcomes do
    not depend on the platform's locale encoding.
    """

    @staticmethod
    def _entries(directory):
        """Return the sorted names of every entry directly inside *directory*."""
        return sorted(entry.name for entry in directory.iterdir())

    def test_writes_valid_json(self, tmp_path):
        target = tmp_path / "data.json"
        data = {"key": "value", "nested": {"a": 1}}
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == data

    def test_creates_parent_directories(self, tmp_path):
        target = tmp_path / "deep" / "nested" / "dir" / "data.json"
        atomic_json_write(target, {"ok": True})

        assert target.exists()
        assert json.loads(target.read_text(encoding="utf-8"))["ok"] is True

    def test_overwrites_existing_file(self, tmp_path):
        target = tmp_path / "data.json"
        target.write_text('{"old": true}', encoding="utf-8")

        atomic_json_write(target, {"new": True})
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == {"new": True}

    def test_preserves_original_on_serialization_error(self, tmp_path):
        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original), encoding="utf-8")

        # Try to write non-serializable data — should fail
        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # Original file should be untouched
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == original

    def test_no_leftover_temp_files_on_success(self, tmp_path):
        target = tmp_path / "data.json"
        atomic_json_write(target, [1, 2, 3])

        # Only the target may remain; any other entry is a leaked temp file,
        # whatever naming scheme the implementation uses for it.
        assert self._entries(tmp_path) == ["data.json"]

    def test_no_leftover_temp_files_on_failure(self, tmp_path):
        target = tmp_path / "data.json"

        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # The directory started empty, so after a failed write it must still
        # be empty: no leaked temp file and no partially written target.
        assert self._entries(tmp_path) == []

    def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
        class SimulatedAbort(BaseException):
            pass

        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original), encoding="utf-8")

        # A BaseException (not an Exception) mimics KeyboardInterrupt /
        # SystemExit arriving mid-serialization; cleanup must still happen.
        with patch("utils.json.dump", side_effect=SimulatedAbort):
            with pytest.raises(SimulatedAbort):
                atomic_json_write(target, {"new": True})

        assert self._entries(tmp_path) == ["data.json"]
        assert json.loads(target.read_text(encoding="utf-8")) == original

    def test_accepts_string_path(self, tmp_path):
        target = str(tmp_path / "string_path.json")
        atomic_json_write(target, {"string": True})

        result = json.loads(Path(target).read_text(encoding="utf-8"))
        assert result == {"string": True}

    def test_writes_list_data(self, tmp_path):
        target = tmp_path / "list.json"
        data = [1, "two", {"three": 3}]
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == data

    def test_empty_list(self, tmp_path):
        target = tmp_path / "empty.json"
        atomic_json_write(target, [])

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == []

    def test_custom_indent(self, tmp_path):
        target = tmp_path / "custom.json"
        atomic_json_write(target, {"a": 1}, indent=4)

        text = target.read_text(encoding="utf-8")
        assert '    "a"' in text  # 4-space indent

    def test_accepts_json_dump_default_hook(self, tmp_path):
        class CustomValue:
            def __str__(self):
                return "custom-value"

        target = tmp_path / "custom_default.json"
        atomic_json_write(target, {"value": CustomValue()}, default=str)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == {"value": "custom-value"}

    def test_unicode_content(self, tmp_path):
        target = tmp_path / "unicode.json"
        data = {"emoji": "🎉", "japanese": "日本語"}
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result["emoji"] == "🎉"
        assert result["japanese"] == "日本語"

    def test_concurrent_writes_dont_corrupt(self, tmp_path):
        """Multiple rapid writes should each produce valid JSON."""
        import threading

        target = tmp_path / "concurrent.json"
        errors = []

        def writer(n):
            try:
                atomic_json_write(target, {"writer": n, "data": list(range(100))})
            except Exception as e:
                # Collected here because an exception raised in a worker
                # thread would otherwise not fail the test.
                errors.append(e)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors
        # The file must hold exactly one writer's complete payload — never an
        # interleaving of several writers.
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result["writer"] in range(10)
        assert result["data"] == list(range(100))
        # No writer may have leaked its temporary file.
        assert self._entries(tmp_path) == ["concurrent.json"]