trl-mcsd / tests /test_vllm_client_server.py
ihbkaiser's picture
Implement MCSD for experimental SDPO
1fa3c6c verified
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from types import SimpleNamespace
import pytest
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from transformers.testing_utils import torch_device
from trl.generation.vllm_client import VLLMClient
from trl.generation.vllm_generation import extract_logprobs
from trl.import_utils import is_vllm_available
from trl.scripts.vllm_serve import chunk_list
from .testing_utils import (
TrlTestCase,
kill_process,
require_3_accelerators,
require_torch_multi_accelerator,
require_vision,
require_vllm,
)
# vLLM is optional: import it, and derive the version flag, only when it is installed.
if is_vllm_available():
    import vllm
    from vllm import LLM, SamplingParams
    # True when the installed vLLM is 0.14.0 or newer; used as the `skipif` condition of the data-parallel tests.
    _is_vllm_ge_014 = Version(vllm.__version__) >= Version("0.14.0")
else:
    # No vLLM installed: keep the flag defined so the `skipif` decorator below can still be evaluated.
    _is_vllm_ge_014 = False
class TestChunkList(TrlTestCase):
    """Checks how `chunk_list` spreads the elements of a list over a requested number of chunks."""

    def test_even_split(self):
        # Six elements over two chunks: three elements in each.
        items = [1, 2, 3, 4, 5, 6]
        expected = [[1, 2, 3], [4, 5, 6]]
        assert chunk_list(items, 2) == expected

    def test_uneven_split(self):
        # Six elements over four chunks: the leading chunks receive the extra elements.
        items = [1, 2, 3, 4, 5, 6]
        expected = [[1, 2], [3, 4], [5], [6]]
        assert chunk_list(items, 4) == expected

    def test_more_chunks_than_elements(self):
        # Asking for more chunks than elements pads the result with empty chunks.
        items = [1, 2, 3, 4, 5, 6]
        expected = [[1], [2], [3], [4], [5], [6], [], []]
        assert chunk_list(items, 8) == expected

    def test_n_equals_len(self):
        # One chunk per element.
        items = [1, 2, 3]
        expected = [[1], [2], [3]]
        assert chunk_list(items, 3) == expected

    def test_n_is_1(self):
        # A single chunk holds the whole list.
        items = [1, 2, 3]
        expected = [[1, 2, 3]]
        assert chunk_list(items, 1) == expected

    def test_single_element_list(self):
        # The only element lands in the first chunk; the second chunk stays empty.
        items = [42]
        expected = [[42], []]
        assert chunk_list(items, 2) == expected

    def test_any_dtype(self):
        # Elements are opaque to the chunking: mixed types are split purely by position.
        items = [1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]]
        expected = [
            [1, "two", 3.0],
            [{"four": 4}, ["f", "i", "v", "e"]],
        ]
        assert chunk_list(items, 2) == expected
class TestExtractLogprobs(TrlTestCase):
    """Unit tests for `extract_logprobs`, driven by `SimpleNamespace` stand-ins for generation outputs."""

    def test_extract_logprobs_sorts_by_rank_and_replaces_nan(self):
        # Per-token dicts map token ID -> object exposing `rank` and `logprob`. Entries are inserted out of rank
        # order on purpose, and one logprob is NaN.
        first_sequence = [
            {
                11: SimpleNamespace(rank=1, logprob=-0.2),
                99: SimpleNamespace(rank=0, logprob=-0.1),
                42: SimpleNamespace(rank=2, logprob=float("nan")),
            },
            {
                5: SimpleNamespace(rank=0, logprob=-1.1),
            },
        ]
        second_sequence = [
            {
                3: SimpleNamespace(rank=1, logprob=-0.5),
                7: SimpleNamespace(rank=0, logprob=-0.4),
            }
        ]
        # Each request output carries exactly one completion whose `logprobs` is the per-token list.
        all_outputs = [
            SimpleNamespace(outputs=[SimpleNamespace(logprobs=per_token)])
            for per_token in (first_sequence, second_sequence)
        ]

        all_logprobs, all_token_ids = extract_logprobs(all_outputs)

        # Token IDs come back ordered by ascending rank ...
        expected_token_ids = [
            [[99, 11, 42], [5]],
            [[7, 3]],
        ]
        # ... and logprobs follow the same order, with the NaN entry reported as None.
        expected_logprobs = [
            [[-0.1, -0.2, None], [-1.1]],
            [[-0.4, -0.5]],
        ]
        assert all_token_ids == expected_token_ids
        assert all_logprobs == expected_logprobs

    def test_extract_logprobs_returns_none_token_ids_when_logprobs_missing(self):
        # A completion without logprobs makes both returned values None.
        completion = SimpleNamespace(logprobs=None)
        all_outputs = [SimpleNamespace(outputs=[completion])]

        all_logprobs, all_token_ids = extract_logprobs(all_outputs)

        assert all_logprobs is None
        assert all_token_ids is None
@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
class TestVLLMClientServer(TrlTestCase):
    """Exercise `VLLMClient` (created with `host=`) against one `trl vllm-serve` subprocess shared by all tests."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        """Launch the server on accelerator 1, then create the client and initialize its communicator.

        pytest does not call `teardown_class` when `setup_class` raises, so the server process is killed here if the
        client side of the setup fails; otherwise the subprocess would be left running.
        """
        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

        # Initialize the client
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
            cls.client.init_communicator()
        except BaseException:
            # Also covers an interrupt while waiting: the error is re-raised, we only make sure the server is gone.
            kill_process(cls.server_process)
            raise

    def test_generate(self):
        """`generate` on text prompts returns one prompt/completion ID sequence per prompt."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_logprobs_none(self):
        """With `logprobs=None`, `generate` still returns IDs but no logprob data."""
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        """`chat` on conversations returns one prompt/completion ID sequence per conversation."""
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of messages
        assert len(prompt_ids) == len(messages)
        assert len(completion_ids) == len(messages)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_chat_with_logprobs_none(self):
        """With `logprobs=None`, `chat` still returns IDs but no logprob data."""
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        """A tool passed to `chat` shows up (via its docstring) in the rendered prompt."""

        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # Decode prompt and check that "Multiplies two integers." is in the prompt.
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        """`generate` accepts pre-tokenized prompts and echoes them back unchanged as `prompt_ids`."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that prompt_ids match the input token IDs
        assert prompt_ids == prompt_token_ids

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_params(self):
        """Sampling parameters are honoured: `n=2` doubles the completions and `max_tokens` caps their length."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # Check that the output is a list
        assert isinstance(completion_ids, list)

        # Check that the number of generated sequences is 2 times the number of prompts
        assert len(completion_ids) == 2 * len(prompts)

        # Check that the generated sequences are lists of integers
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

        # Check that the length of the generated sequences is less than or equal to 32
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        """Smoke test: pushing the weights of a locally loaded model to the server does not raise."""
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        """Smoke test: resetting the prefix cache does not raise."""
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @pytest.mark.xfail(reason="Importing `bitsandbytes` causes issues, see vllm-project/vllm#32793")
    def test_logprobs_match_with_non_default_sampling(self):
        """Server-side generation and logprobs match an in-process `LLM` run with the same seed and parameters."""
        prompts = ["Hello, AI!", "Tell me a joke"]

        # Use non-default sampling parameters (especially temperature) to ensure vLLM applies logprob processing. With
        # default sampling, raw and processed logprobs are identical, so mismatches would not be detected.
        temperature = 0.7
        repetition_penalty = 1.05
        top_p = 0.9
        max_tokens = 8
        seed = 1234
        num_logprobs = 5

        server_outputs = self.client.generate(
            prompts,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            generation_kwargs={"seed": seed},
        )

        # Request the `spawn` start method before the in-process engine is built, and restore the previous value
        # afterwards so the override does not leak into the rest of the test session.
        multiproc_key = "VLLM_WORKER_MULTIPROC_METHOD"
        previous_multiproc_method = os.environ.get(multiproc_key)
        os.environ[multiproc_key] = "spawn"
        try:
            llm = LLM(
                model=self.model_id,
                tensor_parallel_size=1,
                gpu_memory_utilization=0.2,
                max_model_len=128,
                logprobs_mode="processed_logprobs",
            )
            sampling_params = SamplingParams(
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
                max_tokens=max_tokens,
                logprobs=num_logprobs,
                seed=seed,
            )
            colocate_outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)
        finally:
            if previous_multiproc_method is None:
                os.environ.pop(multiproc_key, None)
            else:
                os.environ[multiproc_key] = previous_multiproc_method

        colocate_prompt_ids = [output.prompt_token_ids for output in colocate_outputs]
        colocate_completion_ids = [
            list(output.token_ids) for outputs in colocate_outputs for output in outputs.outputs
        ]
        colocate_logprobs, colocate_logprob_token_ids = extract_logprobs(colocate_outputs)

        # Generation correctness: prompt and completion IDs match between server and colocate
        assert server_outputs["prompt_ids"] == colocate_prompt_ids
        assert server_outputs["completion_ids"] == colocate_completion_ids

        server_logprobs = server_outputs["logprobs"]
        server_logprob_token_ids = server_outputs["logprob_token_ids"]

        # Shape: both should be (num_sequences, seq_len, num_logprobs) with multiple logprobs per token
        assert len(server_logprobs) == len(prompts)
        assert len(server_logprob_token_ids) == len(prompts)
        for seq_lps in server_logprobs:
            for token_lps in seq_lps:
                assert len(token_lps) > 1, "Expected multiple logprobs per token when logprobs > 0"

        # Value correctness: server extraction matches colocate extraction via extract_logprobs
        assert server_logprob_token_ids == colocate_logprob_token_ids
        for server_seq, colocate_seq in zip(server_logprobs, colocate_logprobs, strict=True):
            assert len(server_seq) == len(colocate_seq)
            for server_token_lps, colocate_token_lps in zip(server_seq, colocate_seq, strict=True):
                assert server_token_lps == pytest.approx(colocate_token_lps, rel=1e-6, abs=1e-6)

        # Ordering: logprobs at each position should be sorted descending
        for seq_lps in server_logprobs:
            for token_lps in seq_lps:
                assert token_lps == sorted(token_lps, reverse=True), "Logprobs should be sorted descending"

        # Sampled token presence: the actual completion token should appear in the logprob token IDs
        for seq_idx, (completion_seq, token_ids_seq) in enumerate(
            zip(server_outputs["completion_ids"], server_logprob_token_ids, strict=True)
        ):
            for pos, (sampled_id, lp_ids) in enumerate(zip(completion_seq, token_ids_seq, strict=True)):
                assert sampled_id in lp_ids, (
                    f"Sampled token {sampled_id} not found in logprob token IDs {lp_ids} "
                    f"at sequence {seq_idx}, position {pos}"
                )

    @classmethod
    def teardown_class(cls):
        """Close the client communicator and always kill the server process, even if closing fails."""
        try:
            # Close the client
            cls.client.close_communicator()
        finally:
            # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we
            # need to kill the server process and its children explicitly.
            kill_process(cls.server_process)
# Same as above but using base_url to instantiate the client.
@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
class TestVLLMClientServerBaseURL(TrlTestCase):
    """Exercise `VLLMClient` created with `base_url=` against one `trl vllm-serve` subprocess shared by all tests."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        """Launch the server on accelerator 1, then create the client and initialize its communicator.

        pytest does not call `teardown_class` when `setup_class` raises, so the server process is killed here if the
        client side of the setup fails; otherwise the subprocess would be left running.
        """
        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

        # Initialize the client
        try:
            cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
            cls.client.init_communicator()
        except BaseException:
            # Also covers an interrupt while waiting: the error is re-raised, we only make sure the server is gone.
            kill_process(cls.server_process)
            raise

    def test_generate(self):
        """`generate` on text prompts returns one prompt/completion ID sequence per prompt."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_logprobs_none(self):
        """With `logprobs=None`, `generate` still returns IDs but no logprob data."""
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        """`chat` on conversations returns one prompt/completion ID sequence per conversation."""
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of messages
        assert len(prompt_ids) == len(messages)
        assert len(completion_ids) == len(messages)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_chat_with_logprobs_none(self):
        """With `logprobs=None`, `chat` still returns IDs but no logprob data."""
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        """A tool passed to `chat` shows up (via its docstring) in the rendered prompt."""

        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # Decode prompt and check that "Multiplies two integers." is in the prompt.
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        """`generate` accepts pre-tokenized prompts and echoes them back unchanged as `prompt_ids`."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that prompt_ids match the input token IDs
        assert prompt_ids == prompt_token_ids

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_params(self):
        """Sampling parameters are honoured: `n=2` doubles the completions and `max_tokens` caps their length."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # Check that the output is a list
        assert isinstance(completion_ids, list)

        # Check that the number of generated sequences is 2 times the number of prompts
        assert len(completion_ids) == 2 * len(prompts)

        # Check that the generated sequences are lists of integers
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

        # Check that the length of the generated sequences is less than or equal to 32
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        """Smoke test: pushing the weights of a locally loaded model to the server does not raise."""
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        """Smoke test: resetting the prefix cache does not raise."""
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def teardown_class(cls):
        """Close the client communicator and always kill the server process, even if closing fails."""
        try:
            # Close the client
            cls.client.close_communicator()
        finally:
            # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we
            # need to kill the server process and its children explicitly.
            kill_process(cls.server_process)
@pytest.mark.slow
@require_3_accelerators
@require_vllm
class TestVLLMClientServerTP(TrlTestCase):
    """Exercise `VLLMClient` against a `trl vllm-serve` subprocess started with `--tensor_parallel_size 2`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        """Launch the tensor-parallel server on accelerators 1 and 2, then connect and initialize the client.

        pytest does not call `teardown_class` when `setup_class` raises, so the server process is killed here if the
        client side of the setup fails; otherwise the subprocess would be left running.
        """
        # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1,2"  # Restrict to accelerator 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
            cls.client.init_communicator()
        except BaseException:
            # Also covers an interrupt while waiting: the error is re-raised, we only make sure the server is gone.
            kill_process(cls.server_process)
            raise

    def test_generate(self):
        """`generate` on text prompts returns one prompt/completion ID sequence per prompt."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_logprobs_none(self):
        """With `logprobs=None`, `generate` still returns IDs but no logprob data."""
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        """`chat` on conversations returns one prompt/completion ID sequence per conversation."""
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of messages
        assert len(prompt_ids) == len(messages)
        assert len(completion_ids) == len(messages)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_chat_with_logprobs_none(self):
        """With `logprobs=None`, `chat` still returns IDs but no logprob data."""
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        """A tool passed to `chat` shows up (via its docstring) in the rendered prompt."""

        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # Decode prompt and check that "Multiplies two integers." is in the prompt.
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        """`generate` accepts pre-tokenized prompts and echoes them back unchanged as `prompt_ids`."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that prompt_ids match the input token IDs
        assert prompt_ids == prompt_token_ids

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_params(self):
        """Sampling parameters are honoured: `n=2` doubles the completions and `max_tokens` caps their length."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # Check that the output is a list
        assert isinstance(completion_ids, list)

        # Check that the number of generated sequences is 2 times the number of prompts
        assert len(completion_ids) == 2 * len(prompts)

        # Check that the generated sequences are lists of integers
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

        # Check that the length of the generated sequences is less than or equal to 32
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        """Smoke test: pushing the weights of a locally loaded model to the server does not raise."""
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        """Smoke test: resetting the prefix cache does not raise."""
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def teardown_class(cls):
        """Close the client communicator and always kill the server process, even if closing fails."""
        try:
            # Close the client
            cls.client.close_communicator()
        finally:
            # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we
            # need to kill the server process and its children explicitly.
            kill_process(cls.server_process)
@pytest.mark.slow
@pytest.mark.skipif(
    _is_vllm_ge_014,
    reason="Skipping DP server test for vLLM>=0.14.0 (PR vllm#30739: DP for non-MoE/dense models no longer supported).",
)
@require_3_accelerators
@require_vllm
class TestVLLMClientServerDP(TrlTestCase):
    """Exercise `VLLMClient` against a `trl vllm-serve` subprocess started with `--data_parallel_size 2`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        """Launch the data-parallel server on accelerators 1 and 2, then connect and initialize the client.

        pytest does not call `teardown_class` when `setup_class` raises, so the server process is killed here if the
        client side of the setup fails; otherwise the subprocess would be left running.
        """
        # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1,2"  # Restrict to accelerator 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--data_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
            cls.client.init_communicator()
        except BaseException:
            # Also covers an interrupt while waiting: the error is re-raised, we only make sure the server is gone.
            kill_process(cls.server_process)
            raise

    def test_generate(self):
        """`generate` on text prompts returns one prompt/completion ID sequence per prompt."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_logprobs_none(self):
        """With `logprobs=None`, `generate` still returns IDs but no logprob data."""
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        """`chat` on conversations returns one prompt/completion ID sequence per conversation."""
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of messages
        assert len(prompt_ids) == len(messages)
        assert len(completion_ids) == len(messages)

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_chat_with_logprobs_none(self):
        """With `logprobs=None`, `chat` still returns IDs but no logprob data."""
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)
        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        """A tool passed to `chat` shows up (via its docstring) in the rendered prompt."""

        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # Decode prompt and check that "Multiplies two integers." is in the prompt.
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        """`generate` accepts pre-tokenized prompts and echoes them back unchanged as `prompt_ids`."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        # Check that the outputs are lists
        assert isinstance(prompt_ids, list)
        assert isinstance(completion_ids, list)

        # Check that the number of sequences are equal to the number of prompts
        assert len(prompt_ids) == len(prompts)
        assert len(completion_ids) == len(prompts)

        # Check that prompt_ids match the input token IDs
        assert prompt_ids == prompt_token_ids

        # Check that the sequences are lists of integers
        for seq in prompt_ids:
            assert all(isinstance(tok, int) for tok in seq)
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate_with_params(self):
        """Sampling parameters are honoured: `n=2` doubles the completions and `max_tokens` caps their length."""
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # Check that the output is a list
        assert isinstance(completion_ids, list)

        # Check that the number of generated sequences is 2 times the number of prompts
        assert len(completion_ids) == 2 * len(prompts)

        # Check that the generated sequences are lists of integers
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

        # Check that the length of the generated sequences is less than or equal to 32
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        """Smoke test: pushing the weights of a locally loaded model to the server does not raise."""
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        """Smoke test: resetting the prefix cache does not raise."""
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def teardown_class(cls):
        """Close the client communicator and always kill the server process, even if closing fails."""
        try:
            # Close the client
            cls.client.close_communicator()
        finally:
            # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we
            # need to kill the server process and its children explicitly.
            kill_process(cls.server_process)
@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
class TestVLLMClientServerDeviceParameter(TrlTestCase):
    """Test the device parameter functionality in init_communicator."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        """Launch the server on accelerator 1; each test creates (and closes) its own client."""
        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

    def test_init_communicator_with_device_int(self):
        """Test init_communicator with integer device parameter."""
        client = VLLMClient(connection_timeout=240, host="localhost")
        client.init_communicator(device=0)  # Explicitly specify device 0
        try:
            # Test basic functionality
            prompts = ["Hello, AI!"]
            outputs = client.generate(prompts)
            prompt_ids = outputs["prompt_ids"]
            completion_ids = outputs["completion_ids"]
            assert isinstance(prompt_ids, list)
            assert len(prompt_ids) == len(prompts)
            assert isinstance(completion_ids, list)
            assert len(completion_ids) == len(prompts)
        finally:
            # Close even when an assertion fails, so the next test starts from a clean communicator state.
            client.close_communicator()

    def test_init_communicator_with_device_string(self):
        """Test init_communicator with string device parameter."""
        client = VLLMClient(connection_timeout=240, host="localhost")
        # Build an actual string such as "cuda:0" or "xpu:0" from the detected accelerator type; passing the integer 0
        # here would only repeat the integer test above.
        client.init_communicator(device=f"{torch_device}:0")  # Explicitly specify device as string
        try:
            # Test basic functionality
            prompts = ["Hello, AI!"]
            outputs = client.generate(prompts)["completion_ids"]
            assert isinstance(outputs, list)
            assert len(outputs) == len(prompts)
        finally:
            # Close even when an assertion fails, so the next test starts from a clean communicator state.
            client.close_communicator()

    def test_init_communicator_with_torch_device(self):
        """Test init_communicator with torch.device object."""
        import torch

        client = VLLMClient(connection_timeout=240, host="localhost")
        device = torch.device(0)
        client.init_communicator(device=device)  # Explicitly specify torch.device object
        try:
            # Test basic functionality
            prompts = ["Hello, AI!"]
            outputs = client.generate(prompts)["completion_ids"]
            assert isinstance(outputs, list)
            assert len(outputs) == len(prompts)
        finally:
            # Close even when an assertion fails, so the next test starts from a clean communicator state.
            client.close_communicator()

    @classmethod
    def teardown_class(cls):
        """Kill the server process started in `setup_class`."""
        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
        # kill the server process and its children explicitly.
        kill_process(cls.server_process)
@pytest.mark.slow
@require_vllm
@require_vision
class TestVLLMClientServerVLM(TrlTestCase):
    """Generation tests against a `trl vllm-serve` subprocess serving a vision-language model."""

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def setup_class(cls):
        """Launch the server and create a client for it.

        pytest does not call `teardown_class` when `setup_class` raises, so the server process is killed here if the
        client cannot be created; otherwise the subprocess would be left running.
        """
        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )

        # Initialize the client (no communicator needed for generation-only tests)
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
        except BaseException:
            # Also covers an interrupt while waiting: the error is re-raised, we only make sure the server is gone.
            kill_process(cls.server_process)
            raise

    def test_generate_with_token_ids_and_image(self):
        """Pre-tokenized multimodal prompts with one or several images each yield one completion per prompt."""
        from PIL import Image

        processor = AutoProcessor.from_pretrained(self.model_id)

        image1 = Image.new("RGB", (64, 64), color="red")
        image2 = Image.new("RGB", (64, 64), color="blue")
        image3 = Image.new("RGB", (64, 64), color="green")
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image1},
                        {"type": "image", "image": image2},
                        {"type": "text", "text": "What are the differences between these two images?"},
                    ],
                }
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image3},
                        {"type": "text", "text": "What is the color of this image?"},
                    ],
                }
            ],
        ]
        prompt_token_ids = processor.apply_chat_template(
            conversation=messages, tokenize=True, add_generation_prompt=True
        )

        # `images` is given per prompt, in the same order as the prompts.
        outputs = self.client.generate(prompt_token_ids, images=[[image1, image2], [image3]], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        assert all(isinstance(tok, int) for tok in prompt_ids[0])
        assert all(isinstance(tok, int) for tok in completion_ids[0])

    def test_generate_with_token_ids_mixed_images(self):
        """Test a batch where one prompt has an image and the other does not."""
        from PIL import Image

        processor = AutoProcessor.from_pretrained(self.model_id)

        image = Image.new("RGB", (64, 64), color="red")
        messages = [
            [
                {
                    "role": "user",
                    "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}],
                }
            ],
            [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": "What is 1+1?"}],
                }
            ],
        ]
        prompt_token_ids = processor.apply_chat_template(
            conversation=messages, tokenize=True, add_generation_prompt=True
        )

        # The text-only prompt is paired with `None` in the per-prompt `images` list.
        outputs = self.client.generate(prompt_token_ids, images=[[image], None], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        assert all(isinstance(tok, int) for tok in prompt_ids[0])
        assert all(isinstance(tok, int) for tok in prompt_ids[1])
        assert all(isinstance(tok, int) for tok in completion_ids[0])
        assert all(isinstance(tok, int) for tok in completion_ids[1])

    @classmethod
    def teardown_class(cls):
        """Kill the server process started in `setup_class`."""
        kill_process(cls.server_process)