File size: 9,941 Bytes
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import sys
import threading
import time
import pytest # Import pytest
from verl.utils.py_functional import timeout_limit as timeout
# --- Test Task Functions ---
# Deadline, in seconds, handed to every timeout decorator used in this module.
TEST_TIMEOUT_SECONDS = 1.5 # Timeout duration for tests
# Sleep length of the "slow" tasks: half a second past the deadline, so they overrun it.
LONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5 # Duration slightly longer than timeout
@timeout(seconds=TEST_TIMEOUT_SECONDS)  # decorated at module level for the multiprocessing-based tests
def quick_task(x):
    """Sleep for 0.1 s and return ``"quick_ok"``.

    The nap is far shorter than ``TEST_TIMEOUT_SECONDS``, so the timeout is not
    expected to fire. The argument ``x`` is accepted but never used.
    """
    nap_seconds = 0.1
    time.sleep(nap_seconds)
    return "quick_ok"
@timeout(seconds=TEST_TIMEOUT_SECONDS)  # decorated at module level for the multiprocessing-based tests
def slow_task(x):
    """Sleep for ``LONG_TASK_DURATION`` seconds, which is longer than the timeout.

    The argument ``x`` is accepted but never used. If the caller ever receives
    the return value ``"slow_finished"``, the timeout did not interrupt the task.
    """
    nap_seconds = LONG_TASK_DURATION
    time.sleep(nap_seconds)
    return "slow_finished"
# Deliberately left undecorated: tests wrap it with the timeout decorator themselves.
def task_raises_value_error():
    """Always fail with ``ValueError("Specific value error from task")``."""
    message = "Specific value error from task"
    raise ValueError(message)
# --- Quick task guarded by the signal-based timeout ---
# Defined and decorated at module level so it can be looked up by name
# (intended for a subprocess-based signal test).
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_quick_task_signal():
    """Sleep for 0.1 s under a ``use_signals=True`` timeout and return a marker string.

    The marker differs from the one returned by ``quick_task`` so the two can be
    told apart in test output.
    """
    nap_seconds = 0.1
    time.sleep(nap_seconds)
    return "quick_ok_signal_subprocess"
# --- Slow task guarded by the signal-based timeout ---
# Defined and decorated at module level so it can be looked up by name
# (intended for a subprocess-based signal test).
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_slow_task_signal():
    """Sleep for ``LONG_TASK_DURATION`` seconds under a ``use_signals=True`` timeout.

    The sleep is longer than the timeout; ``"slow_finished"`` is returned only
    if the task is allowed to run to completion.
    """
    nap_seconds = LONG_TASK_DURATION
    time.sleep(nap_seconds)
    return "slow_finished"
# --- Module-level helper usable as a multiprocessing.Process target ---
def run_target_and_put_in_queue(target_func, q):
    """Call ``target_func()`` and report the outcome through ``q``.

    A normal return is reported as ``("success", <return value>)``. Any
    ``Exception`` raised along the way is reported as ``("error", <exception>)``.
    Nothing is returned to the caller; the queue is the only output channel.

    Args:
        target_func: Zero-argument callable to execute.
        q: Object with a ``put`` method that receives the outcome tuple.
    """
    try:
        # The put stays inside the try so a failure while reporting success
        # is itself reported as an error, as before.
        q.put(("success", target_func()))
    except Exception as exc:
        q.put(("error", exc))
# Module-scoped, auto-used fixture that pins the multiprocessing start method on macOS.
@pytest.fixture(scope="module", autouse=True)
def set_macos_start_method():
    """Switch multiprocessing to the ``fork`` start method when running on macOS.

    Does nothing on other platforms, or when ``fork`` is already the configured
    start method. The stated motivation is avoiding pickling problems with the
    module-level decorated functions when the tests run via pytest discovery.
    """
    if sys.platform != "darwin":
        return

    # An unset method (None) also differs from "fork", so one comparison covers both cases.
    if multiprocessing.get_start_method(allow_none=True) == "fork":
        return

    try:
        multiprocessing.set_start_method("fork", force=True)
    except RuntimeError:
        # Best effort: the call may be rejected if a context is already in use.
        pass
def test_quick_task():
    """A task that finishes before the deadline returns its value through the timeout wrapper."""
    # quick_task is already decorated at module level, so it is simply called.
    assert quick_task(1) == "quick_ok"
def test_slow_task_timeout():
    """A task that outlasts the deadline makes the timeout wrapper raise ``TimeoutError``."""
    expected_fragment = f"timed out after {TEST_TIMEOUT_SECONDS} seconds"

    # slow_task is already decorated at module level, so it is simply called.
    with pytest.raises(TimeoutError) as caught:
        slow_task(1)

    # The error text must mention the configured deadline.
    assert expected_fragment in str(caught.value)
def test_internal_exception():
    """An exception raised inside the wrapped task reaches the caller with type and message intact."""
    # Build the decorator with default options, then wrap the undecorated task by hand.
    with_timeout = timeout(seconds=TEST_TIMEOUT_SECONDS)
    guarded_task = with_timeout(task_raises_value_error)

    with pytest.raises(ValueError) as caught:
        guarded_task()

    assert str(caught.value) == "Specific value error from task"
# --- Test the signal implementation (use_signals=True) ---
# Note: As per py_functional.py, use_signals=True currently falls back to
# multiprocessing on POSIX. These tests verify that behavior.
def test_signal_quick_task_main_process():
    """With ``use_signals=True``, a quick task called from the main process returns normally."""

    def plain_quick_task_logic():
        time.sleep(0.1)
        return "quick_ok_signal"

    # Wrap the local function by hand instead of using decorator syntax.
    with_signal_timeout = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
    guarded_task = with_signal_timeout(plain_quick_task_logic)

    outcome = guarded_task()
    assert outcome == "quick_ok_signal"
def test_signal_slow_task_main_process_timeout():
    """With ``use_signals=True``, a slow task called from the main process raises ``TimeoutError``."""

    def plain_slow_task_logic():
        time.sleep(LONG_TASK_DURATION)
        return "slow_finished_signal"

    # Wrap the local function by hand instead of using decorator syntax.
    with_signal_timeout = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
    guarded_task = with_signal_timeout(plain_slow_task_logic)
    expected_fragment = f"timed out after {TEST_TIMEOUT_SECONDS} seconds"

    with pytest.raises(TimeoutError) as caught:
        guarded_task()

    # The error text must mention the configured deadline.
    assert expected_fragment in str(caught.value)
@pytest.mark.skip(reason="this test won't pass. Just to show why use_signals should not be used")
def test_signal_in_thread_does_not_timeout():
    """Illustrate that the signal-based timeout is unreliable inside a child thread.

    The assertions describe what a working timeout would produce: exactly one
    captured ``TimeoutError`` and no captured result. Per the skip reason the
    test is not expected to pass; it is kept only as a demonstration.
    """
    outcomes = []  # values returned by the task, filled in by the worker thread
    failures = []  # exceptions seen by the worker thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as exc:
            # Record anything raised inside the task body itself.
            print(f"Thread: Caught exception: {exc}")
            failures.append(exc)
            return None  # signals failure to thread_target

    def thread_target():
        try:
            # Invoke the decorated task from within the worker thread.
            value = slow_task_in_thread()
            if value is not None:
                outcomes.append(value)
        except Exception as exc:
            # Record anything raised around the task body, e.g. by the wrapper.
            print(f"Thread Target: Caught exception: {exc}")
            failures.append(exc)

    worker = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    worker.start()
    # Bounded wait: one second longer than the slow task's own sleep.
    worker.join(timeout=LONG_TASK_DURATION + 1)

    assert len(failures) == 1
    assert isinstance(failures[0], TimeoutError)
    assert not outcomes
def test_in_thread_timeout():
    """Check that the non-signal timeout fires when the decorated call runs in a worker thread.

    A task decorated with ``use_signals=False`` sleeps past the deadline inside
    a ``threading.Thread``. The test expects the worker to have finished within
    the join timeout, exactly one captured exception of type ``TimeoutError``,
    and no captured result.
    """
    result_container = []  # values returned by the task, filled in by the worker thread
    exception_container = []  # exceptions seen by the worker thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # Record anything raised inside the task body itself.
            print(f"Thread: Caught exception: {e}")
            exception_container.append(e)
            return None  # signals failure to thread_target

    def thread_target():
        try:
            # Invoke the decorated task from within the worker thread.
            res = slow_task_in_thread()
            if res is not None:
                result_container.append(res)
        except Exception as e:
            # Record anything raised around the task body, e.g. by the timeout wrapper.
            print(f"Thread Target: Caught exception: {e}")
            exception_container.append(e)

    thread = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    thread.start()
    # Bounded wait: one second longer than the slow task's own sleep.
    thread.join(timeout=LONG_TASK_DURATION + 1)

    # join() returns None whether or not it timed out, so check liveness
    # explicitly; otherwise a hung worker shows up as a confusing
    # "len(...) == 1" failure below.
    assert not thread.is_alive(), "worker thread is still running after the join timeout"

    assert len(exception_container) == 1
    assert isinstance(exception_container[0], TimeoutError)
    assert not result_container
|