File size: 9,941 Bytes
1faccd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import sys
import threading
import time

import pytest  # Import pytest

from verl.utils.py_functional import timeout_limit as timeout

# --- Test Task Functions ---
TEST_TIMEOUT_SECONDS = 1.5  # Timeout duration for tests
LONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5  # Duration slightly longer than timeout


@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests
def quick_task(x):
    """A task that completes quickly."""
    time.sleep(0.1)
    return "quick_ok"


@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests
def slow_task(x):
    """A task that takes longer than the timeout."""
    time.sleep(LONG_TASK_DURATION)
    return "slow_finished"  # This return value indicates it didn't time out


# REMOVE global decorator here
def task_raises_value_error():  # Now truly not globally decorated
    """A task that intentionally raises a ValueError."""
    raise ValueError("Specific value error from task")


# --- Top-level function for signal test in subprocess ---
# Keep this decorated globally for the specific subprocess test case
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_quick_task_signal():
    """A pickleable top-level function decorated with signal timeout."""
    # Assuming this calls the logic of quick_task directly for the test purpose
    time.sleep(0.1)
    return "quick_ok_signal_subprocess"  # Different return for clarity if needed


# --- Top-level function for signal test in subprocess ---
# Keep this decorated globally for the specific subprocess test case
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_slow_task_signal():
    """A pickleable top-level function decorated with signal timeout."""
    time.sleep(LONG_TASK_DURATION)
    return "slow_finished"


# --- NEW: Top-level helper function to run target in process ---
def run_target_and_put_in_queue(target_func, q):
    """
    Top-level helper function to run a target function and put its result or exception into a queue.
    This function is pickleable and can be used as the target for multiprocessing.Process.
    """
    try:
        result = target_func()
        q.put(("success", result))
    except Exception as e:
        q.put(("error", e))


# Use a module-level fixture to set the start method on macOS
@pytest.fixture(scope="module", autouse=True)  # Changed scope to module
def set_macos_start_method():
    if sys.platform == "darwin":
        # Force fork method on macOS to avoid pickling issues with globally decorated functions
        # when running tests via pytest discovery.
        current_method = multiprocessing.get_start_method(allow_none=True)
        # Only set if not already set or if set to something else (less likely in test run)
        if current_method is None or current_method != "fork":
            try:
                multiprocessing.set_start_method("fork", force=True)
            except RuntimeError:
                # Might fail if context is already started, ignore in that case.
                pass


def test_quick_task():  # Renamed from test_multiprocessing_quick_task
    """Tests timeout handles a quick task correctly."""
    # Call the globally decorated function directly
    result = quick_task(1)
    assert result == "quick_ok"  # Use pytest assert


def test_slow_task_timeout():  # Renamed from test_multiprocessing_slow_task_timeout
    """Tests timeout correctly raises TimeoutError for a slow task."""
    # Call the globally decorated function directly within pytest.raises
    with pytest.raises(TimeoutError) as excinfo:  # Use pytest.raises
        slow_task(1)
    # Check the error message from the multiprocessing implementation
    assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value)  # Use pytest assert


def test_internal_exception():  # Renamed from test_multiprocessing_internal_exception
    """Tests timeout correctly propagates internal exceptions."""
    # Apply the default timeout decorator dynamically to the undecorated function
    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error)  # Apply decorator dynamically
    with pytest.raises(ValueError) as excinfo:  # Use pytest.raises
        decorated_task()  # Call the dynamically decorated function
    assert str(excinfo.value) == "Specific value error from task"  # Use pytest assert


# --- Test the signal implementation (use_signals=True) ---
# Note: As per py_functional.py, use_signals=True currently falls back to
# multiprocessing on POSIX. These tests verify that behavior.


def test_signal_quick_task_main_process():  # Removed self
    """Tests signal timeout handles a quick task correctly in the main process."""

    # Apply the signal decorator dynamically
    def plain_quick_task_logic():
        time.sleep(0.1)
        return "quick_ok_signal"

    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_quick_task_logic)
    assert decorated_task() == "quick_ok_signal"  # Use pytest assert


def test_signal_slow_task_main_process_timeout():  # Removed self
    """Tests signal timeout correctly raises TimeoutError for a slow task in the main process."""

    # Apply the signal decorator dynamically
    def plain_slow_task_logic():
        time.sleep(LONG_TASK_DURATION)
        return "slow_finished_signal"

    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_slow_task_logic)
    with pytest.raises(TimeoutError) as excinfo:  # Use pytest.raises
        decorated_task()
    # Check the error message (falls back to multiprocessing message on POSIX)
    assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value)  # Use pytest assert


@pytest.mark.skip(reason="this test won't pass. Just to show why use_signals should not be used")
def test_signal_in_thread_does_not_timeout():
    """
    Tests that signal-based timeout does NOT work reliably in a child thread.
    The TimeoutError from the signal handler is not expected to be raised.
    """
    result_container = []  # Use a list to store result from thread
    exception_container = []  # Use a list to store exception from thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # Catch any exception within the thread's target function
            print(f"Thread: Caught exception: {e}")
            exception_container.append(e)
            return None  # Indicate failure

    def thread_target():
        try:
            # Run the decorated function inside the thread
            res = slow_task_in_thread()
            if res is not None:
                result_container.append(res)
        except Exception as e:
            # This might catch exceptions happening *outside* the decorated function
            # but still within the thread target, though less likely here.
            print(f"Thread Target: Caught exception: {e}")
            exception_container.append(e)

    thread = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    thread.start()
    # Wait longer than the timeout + task duration to ensure the thread finishes
    # regardless of whether timeout worked or not.
    thread.join(timeout=LONG_TASK_DURATION + 1)

    assert len(exception_container) == 1
    assert isinstance(exception_container[0], TimeoutError)
    assert not result_container


def test_in_thread_timeout():
    result_container = []  # Use a list to store result from thread
    exception_container = []  # Use a list to store exception from thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # Catch any exception within the thread's target function
            print(f"Thread: Caught exception: {e}")
            exception_container.append(e)
            return None  # Indicate failure

    def thread_target():
        try:
            # Run the decorated function inside the thread
            res = slow_task_in_thread()
            if res is not None:
                result_container.append(res)
        except Exception as e:
            # This might catch exceptions happening *outside* the decorated function
            # but still within the thread target, though less likely here.
            print(f"Thread Target: Caught exception: {e}")
            exception_container.append(e)

    thread = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    thread.start()
    # Wait longer than the timeout + task duration to ensure the thread finishes
    # regardless of whether timeout worked or not.
    thread.join(timeout=LONG_TASK_DURATION + 1)

    assert len(exception_container) == 1
    assert isinstance(exception_container[0], TimeoutError)
    assert not result_container