|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import subprocess
|
| from types import SimpleNamespace
|
|
|
| import pytest
|
| from packaging.version import Version
|
| from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
| from transformers.testing_utils import torch_device
|
|
|
| from trl.generation.vllm_client import VLLMClient
|
| from trl.generation.vllm_generation import extract_logprobs
|
| from trl.import_utils import is_vllm_available
|
| from trl.scripts.vllm_serve import chunk_list
|
|
|
| from .testing_utils import (
|
| TrlTestCase,
|
| kill_process,
|
| require_3_accelerators,
|
| require_torch_multi_accelerator,
|
| require_vision,
|
| require_vllm,
|
| )
|
|
|
|
|
# vLLM is an optional dependency: it is imported only when available. `_is_vllm_ge_014` records
# whether the installed version is at least 0.14.0 and stays False when vLLM is not installed.
_is_vllm_ge_014 = False
if is_vllm_available():
    import vllm
    from vllm import LLM, SamplingParams

    _is_vllm_ge_014 = Version(vllm.__version__) >= Version("0.14.0")
|
|
|
|
|
class TestChunkList(TrlTestCase):
    """Checks how `chunk_list` distributes the elements of a list over `n` chunks."""

    def test_even_split(self):
        # Six elements over two chunks: three elements each
        items = [1, 2, 3, 4, 5, 6]
        chunks = chunk_list(items, 2)
        assert chunks == [[1, 2, 3], [4, 5, 6]]

    def test_uneven_split(self):
        # Six elements over four chunks: the leading chunks receive the extra elements
        items = [1, 2, 3, 4, 5, 6]
        chunks = chunk_list(items, 4)
        assert chunks == [[1, 2], [3, 4], [5], [6]]

    def test_more_chunks_than_elements(self):
        # More chunks than elements: the surplus chunks are empty lists
        items = [1, 2, 3, 4, 5, 6]
        chunks = chunk_list(items, 8)
        assert chunks == [[1], [2], [3], [4], [5], [6], [], []]

    def test_n_equals_len(self):
        # As many chunks as elements: one element per chunk
        chunks = chunk_list([1, 2, 3], 3)
        assert chunks == [[1], [2], [3]]

    def test_n_is_1(self):
        # A single chunk holds the whole list
        chunks = chunk_list([1, 2, 3], 1)
        assert chunks == [[1, 2, 3]]

    def test_single_element_list(self):
        # One element over two chunks: the second chunk is empty
        chunks = chunk_list([42], 2)
        assert chunks == [[42], []]

    def test_any_dtype(self):
        # Elements of mixed types are chunked like any other element
        mixed = [1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]]
        expected = [
            [1, "two", 3.0],
            [{"four": 4}, ["f", "i", "v", "e"]],
        ]
        assert chunk_list(mixed, 2) == expected
|
|
|
|
|
class TestExtractLogprobs(TrlTestCase):
    """Exercises `extract_logprobs` on `SimpleNamespace` stand-ins for generation outputs."""

    def test_extract_logprobs_sorts_by_rank_and_replaces_nan(self):
        def entry(rank, logprob):
            # Stand-in for one per-token logprob record
            return SimpleNamespace(rank=rank, logprob=logprob)

        def request(*positions):
            # One request holding a single completion; `logprobs` has one dict per position
            return SimpleNamespace(outputs=[SimpleNamespace(logprobs=list(positions))])

        # The dicts are deliberately not in rank order, and one logprob is NaN
        first = request(
            {11: entry(1, -0.2), 99: entry(0, -0.1), 42: entry(2, float("nan"))},
            {5: entry(0, -1.1)},
        )
        second = request({3: entry(1, -0.5), 7: entry(0, -0.4)})

        all_logprobs, all_token_ids = extract_logprobs([first, second])

        # Token IDs come back ordered by rank, and the NaN logprob comes back as None
        expected_token_ids = [
            [[99, 11, 42], [5]],
            [[7, 3]],
        ]
        expected_logprobs = [
            [[-0.1, -0.2, None], [-1.1]],
            [[-0.4, -0.5]],
        ]
        assert all_token_ids == expected_token_ids
        assert all_logprobs == expected_logprobs

    def test_extract_logprobs_returns_none_token_ids_when_logprobs_missing(self):
        # A completion whose `logprobs` attribute is None
        completion = SimpleNamespace(logprobs=None)
        all_outputs = [SimpleNamespace(outputs=[completion])]

        all_logprobs, all_token_ids = extract_logprobs(all_outputs)

        # Both return values are None in that case
        assert all_logprobs is None
        assert all_token_ids is None
|
|
|
|
|
@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
class TestVLLMClientServer(TrlTestCase):
    """End-to-end tests of `VLLMClient` against a `trl vllm-serve` subprocess, connecting via `host`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        # Make only device 1 visible to the server, using the backend-specific environment variable
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

        # Connect the client. pytest does not run `teardown_class` when `setup_class` raises, so the
        # server has to be killed here if the connection fails; otherwise it would be left running.
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
            cls.client.init_communicator()
        except BaseException:
            kill_process(cls.server_process)
            raise

    @staticmethod
    def _assert_token_id_batch(batch, expected_len):
        """Assert that `batch` is a list of `expected_len` sequences that contain only ints."""
        assert isinstance(batch, list)
        assert len(batch) == expected_len
        for seq in batch:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # One prompt and one completion sequence per input prompt, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

    def test_generate_with_logprobs_none(self):
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)

        # One prompt and one completion sequence per conversation, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(messages))
        self._assert_token_id_batch(outputs["completion_ids"], len(messages))

    def test_chat_with_logprobs_none(self):
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # The tool description must appear in the prompt that was actually sent to the model
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]

        self._assert_token_id_batch(prompt_ids, len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

        # Pre-tokenized prompts must be returned unchanged
        assert prompt_ids == prompt_token_ids

    def test_generate_with_params(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # `n=2` yields two completions per prompt
        self._assert_token_id_batch(completion_ids, 2 * len(prompts))

        # No completion may exceed `max_tokens`
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        # Only checks that pushing the weights of a freshly loaded model does not raise
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Only checks that the call does not raise
        self.client.reset_prefix_cache()

    @pytest.mark.xfail(reason="Importing `bitsandbytes` causes issues, see vllm-project/vllm#32793")
    def test_logprobs_match_with_non_default_sampling(self):
        prompts = ["Hello, AI!", "Tell me a joke"]

        # Non-default sampling parameters, shared by the server request and the in-process engine
        temperature = 0.7
        repetition_penalty = 1.05
        top_p = 0.9
        max_tokens = 8
        seed = 1234
        num_logprobs = 5

        server_outputs = self.client.generate(
            prompts,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            generation_kwargs={"seed": seed},
        )

        # The multiprocessing start method is set through the process environment. Remember the
        # previous value and restore it afterwards so this test does not leak state into later tests.
        env_key = "VLLM_WORKER_MULTIPROC_METHOD"
        previous_method = os.environ.get(env_key)
        os.environ[env_key] = "spawn"
        try:
            llm = LLM(
                model=self.model_id,
                tensor_parallel_size=1,
                gpu_memory_utilization=0.2,
                max_model_len=128,
                logprobs_mode="processed_logprobs",
            )

            sampling_params = SamplingParams(
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
                max_tokens=max_tokens,
                logprobs=num_logprobs,
                seed=seed,
            )
            colocate_outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)
        finally:
            if previous_method is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_method

        colocate_prompt_ids = [output.prompt_token_ids for output in colocate_outputs]
        colocate_completion_ids = [
            list(output.token_ids) for outputs in colocate_outputs for output in outputs.outputs
        ]
        colocate_logprobs, colocate_logprob_token_ids = extract_logprobs(colocate_outputs)

        # Same seed and sampling parameters: both paths must produce the same tokens
        assert server_outputs["prompt_ids"] == colocate_prompt_ids
        assert server_outputs["completion_ids"] == colocate_completion_ids

        server_logprobs = server_outputs["logprobs"]
        server_logprob_token_ids = server_outputs["logprob_token_ids"]

        # One entry per prompt, and several logprobs per generated token
        assert len(server_logprobs) == len(prompts)
        assert len(server_logprob_token_ids) == len(prompts)
        for seq_lps in server_logprobs:
            for token_lps in seq_lps:
                assert len(token_lps) > 1, "Expected multiple logprobs per token when logprobs > 0"

        # Server and in-process logprobs must agree, up to a small numerical tolerance
        assert server_logprob_token_ids == colocate_logprob_token_ids
        for server_seq, colocate_seq in zip(server_logprobs, colocate_logprobs, strict=True):
            assert len(server_seq) == len(colocate_seq)
            for server_token_lps, colocate_token_lps in zip(server_seq, colocate_seq, strict=True):
                assert server_token_lps == pytest.approx(colocate_token_lps, rel=1e-6, abs=1e-6)

        # Within each position, logprobs are ordered from most to least likely
        for seq_lps in server_logprobs:
            for token_lps in seq_lps:
                assert token_lps == sorted(token_lps, reverse=True), "Logprobs should be sorted descending"

        # The token that was actually sampled must be among the reported token IDs
        for seq_idx, (completion_seq, token_ids_seq) in enumerate(
            zip(server_outputs["completion_ids"], server_logprob_token_ids, strict=True)
        ):
            for pos, (sampled_id, lp_ids) in enumerate(zip(completion_seq, token_ids_seq, strict=True)):
                assert sampled_id in lp_ids, (
                    f"Sampled token {sampled_id} not found in logprob token IDs {lp_ids} "
                    f"at sequence {seq_idx}, position {pos}"
                )

    @classmethod
    def teardown_class(cls):
        # Close the client communicator first; kill the server even if closing raises
        try:
            cls.client.close_communicator()
        finally:
            kill_process(cls.server_process)
|
|
|
|
|
|
|
@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
class TestVLLMClientServerBaseURL(TrlTestCase):
    """End-to-end tests of `VLLMClient` against a `trl vllm-serve` subprocess, connecting via `base_url`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        # Make only device 1 visible to the server, using the backend-specific environment variable
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

        # Connect the client. pytest does not run `teardown_class` when `setup_class` raises, so the
        # server has to be killed here if the connection fails; otherwise it would be left running.
        try:
            cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
            cls.client.init_communicator()
        except BaseException:
            kill_process(cls.server_process)
            raise

    @staticmethod
    def _assert_token_id_batch(batch, expected_len):
        """Assert that `batch` is a list of `expected_len` sequences that contain only ints."""
        assert isinstance(batch, list)
        assert len(batch) == expected_len
        for seq in batch:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # One prompt and one completion sequence per input prompt, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

    def test_generate_with_logprobs_none(self):
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)

        # One prompt and one completion sequence per conversation, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(messages))
        self._assert_token_id_batch(outputs["completion_ids"], len(messages))

    def test_chat_with_logprobs_none(self):
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # The tool description must appear in the prompt that was actually sent to the model
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]

        self._assert_token_id_batch(prompt_ids, len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

        # Pre-tokenized prompts must be returned unchanged
        assert prompt_ids == prompt_token_ids

    def test_generate_with_params(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # `n=2` yields two completions per prompt
        self._assert_token_id_batch(completion_ids, 2 * len(prompts))

        # No completion may exceed `max_tokens`
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        # Only checks that pushing the weights of a freshly loaded model does not raise
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Only checks that the call does not raise
        self.client.reset_prefix_cache()

    @classmethod
    def teardown_class(cls):
        # Close the client communicator first; kill the server even if closing raises
        try:
            cls.client.close_communicator()
        finally:
            kill_process(cls.server_process)
|
|
|
|
|
@pytest.mark.slow
@require_3_accelerators
@require_vllm
class TestVLLMClientServerTP(TrlTestCase):
    """End-to-end tests of `VLLMClient` against a server started with `--tensor_parallel_size 2`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        # Make devices 1 and 2 visible to the server, using the backend-specific environment variable
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1,2"

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Connect the client. pytest does not run `teardown_class` when `setup_class` raises, so the
        # server has to be killed here if the connection fails; otherwise it would be left running.
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
            cls.client.init_communicator()
        except BaseException:
            kill_process(cls.server_process)
            raise

    @staticmethod
    def _assert_token_id_batch(batch, expected_len):
        """Assert that `batch` is a list of `expected_len` sequences that contain only ints."""
        assert isinstance(batch, list)
        assert len(batch) == expected_len
        for seq in batch:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # One prompt and one completion sequence per input prompt, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

    def test_generate_with_logprobs_none(self):
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)

        # One prompt and one completion sequence per conversation, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(messages))
        self._assert_token_id_batch(outputs["completion_ids"], len(messages))

    def test_chat_with_logprobs_none(self):
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # The tool description must appear in the prompt that was actually sent to the model
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]

        self._assert_token_id_batch(prompt_ids, len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

        # Pre-tokenized prompts must be returned unchanged
        assert prompt_ids == prompt_token_ids

    def test_generate_with_params(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # `n=2` yields two completions per prompt
        self._assert_token_id_batch(completion_ids, 2 * len(prompts))

        # No completion may exceed `max_tokens`
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        # Only checks that pushing the weights of a freshly loaded model does not raise
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Only checks that the call does not raise
        self.client.reset_prefix_cache()

    @classmethod
    def teardown_class(cls):
        # Close the client communicator first; kill the server even if closing raises
        try:
            cls.client.close_communicator()
        finally:
            kill_process(cls.server_process)
|
|
|
|
|
@pytest.mark.slow
@pytest.mark.skipif(
    _is_vllm_ge_014,
    reason="Skipping DP server test for vLLM>=0.14.0 (PR vllm#30739: DP for non-MoE/dense models no longer supported).",
)
@require_3_accelerators
@require_vllm
class TestVLLMClientServerDP(TrlTestCase):
    """End-to-end tests of `VLLMClient` against a server started with `--data_parallel_size 2`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        # Make devices 1 and 2 visible to the server, using the backend-specific environment variable
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1,2"

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--data_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Connect the client. pytest does not run `teardown_class` when `setup_class` raises, so the
        # server has to be killed here if the connection fails; otherwise it would be left running.
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
            cls.client.init_communicator()
        except BaseException:
            kill_process(cls.server_process)
            raise

    @staticmethod
    def _assert_token_id_batch(batch, expected_len):
        """Assert that `batch` is a list of `expected_len` sequences that contain only ints."""
        assert isinstance(batch, list)
        assert len(batch) == expected_len
        for seq in batch:
            assert all(isinstance(tok, int) for tok in seq)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # One prompt and one completion sequence per input prompt, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

    def test_generate_with_logprobs_none(self):
        outputs = self.client.generate(["Hello, AI!"], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat(self):
        messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
        outputs = self.client.chat(messages)

        # One prompt and one completion sequence per conversation, all made of integer token IDs
        self._assert_token_id_batch(outputs["prompt_ids"], len(messages))
        self._assert_token_id_batch(outputs["completion_ids"], len(messages))

    def test_chat_with_logprobs_none(self):
        outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

        assert isinstance(outputs["prompt_ids"], list)
        assert isinstance(outputs["completion_ids"], list)
        # With `logprobs=None`, neither logprobs nor their token IDs are returned
        assert outputs["logprobs"] is None
        assert outputs["logprob_token_ids"] is None

    def test_chat_with_tools(self):
        def multiply(a: int, b: int) -> int:
            """
            Multiplies two integers.

            Args:
                a: The first integer.
                b: The second integer.

            Returns:
                The product of the two integers.
            """
            return a * b

        messages = [[{"role": "user", "content": "What is 3 multiplied by 4?"}]]
        outputs = self.client.chat(messages, tools=[multiply])

        # The tool description must appear in the prompt that was actually sent to the model
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
        assert "Multiplies two integers." in decoded_prompt

    def test_generate_with_token_ids(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        prompts = ["Hello, AI!", "Tell me a joke"]
        prompt_token_ids = tokenizer(prompts)["input_ids"]
        outputs = self.client.generate(prompt_token_ids)
        prompt_ids = outputs["prompt_ids"]

        self._assert_token_id_batch(prompt_ids, len(prompts))
        self._assert_token_id_batch(outputs["completion_ids"], len(prompts))

        # Pre-tokenized prompts must be returned unchanged
        assert prompt_ids == prompt_token_ids

    def test_generate_with_params(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]

        # `n=2` yields two completions per prompt
        self._assert_token_id_batch(completion_ids, 2 * len(prompts))

        # No completion may exceed `max_tokens`
        for seq in completion_ids:
            assert len(seq) <= 32

    def test_update_model_params(self):
        # Only checks that pushing the weights of a freshly loaded model does not raise
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Only checks that the call does not raise
        self.client.reset_prefix_cache()

    @classmethod
    def teardown_class(cls):
        # Close the client communicator first; kill the server even if closing raises
        try:
            cls.client.close_communicator()
        finally:
            kill_process(cls.server_process)
|
|
|
|
|
@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
class TestVLLMClientServerDeviceParameter(TrlTestCase):
    """Test the device parameter functionality in init_communicator."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setup_class(cls):
        # Make only device 1 visible to the server, using the backend-specific environment variable
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"

        # Start the server process; each test creates (and closes) its own client
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

    def test_init_communicator_with_device_int(self):
        """Test init_communicator with integer device parameter."""
        client = VLLMClient(connection_timeout=240, host="localhost")
        client.init_communicator(device=0)
        try:
            # Check that generation works after initialising with an explicit device
            prompts = ["Hello, AI!"]
            outputs = client.generate(prompts)
            prompt_ids = outputs["prompt_ids"]
            completion_ids = outputs["completion_ids"]
            assert isinstance(prompt_ids, list)
            assert len(prompt_ids) == len(prompts)
            assert isinstance(completion_ids, list)
            assert len(completion_ids) == len(prompts)
        finally:
            # Close the communicator even when an assertion fails, so later tests can reconnect
            client.close_communicator()

    def test_init_communicator_with_device_string(self):
        """Test init_communicator with string device parameter."""
        client = VLLMClient(connection_timeout=240, host="localhost")
        # Pass an actual string such as "cuda:0" so that the string code path is exercised.
        # Any index already present in `torch_device` is dropped before appending ":0".
        device_type = torch_device.split(":")[0]
        client.init_communicator(device=f"{device_type}:0")
        try:
            # Check that generation works after initialising with an explicit device
            prompts = ["Hello, AI!"]
            outputs = client.generate(prompts)["completion_ids"]
            assert isinstance(outputs, list)
            assert len(outputs) == len(prompts)
        finally:
            # Close the communicator even when an assertion fails, so later tests can reconnect
            client.close_communicator()

    def test_init_communicator_with_torch_device(self):
        """Test init_communicator with torch.device object."""
        import torch

        client = VLLMClient(connection_timeout=240, host="localhost")
        device = torch.device(0)
        client.init_communicator(device=device)
        try:
            # Check that generation works after initialising with an explicit device
            prompts = ["Hello, AI!"]
            outputs = client.generate(prompts)["completion_ids"]
            assert isinstance(outputs, list)
            assert len(outputs) == len(prompts)
        finally:
            # Close the communicator even when an assertion fails, so later tests can reconnect
            client.close_communicator()

    @classmethod
    def teardown_class(cls):
        # No class-level client exists here; only the server process has to be stopped
        kill_process(cls.server_process)
|
|
|
|
|
@pytest.mark.slow
@require_vllm
@require_vision
class TestVLLMClientServerVLM(TrlTestCase):
    """End-to-end tests of `VLLMClient.generate` with image inputs against a vision-language model."""

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def setup_class(cls):
        # Start the server process (inherits the current environment)
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )

        # Create the client. pytest does not run `teardown_class` when `setup_class` raises, so the
        # server has to be killed here if the client cannot be created; otherwise it would be left running.
        try:
            cls.client = VLLMClient(connection_timeout=240, host="localhost")
        except BaseException:
            kill_process(cls.server_process)
            raise

    def test_generate_with_token_ids_and_image(self):
        from PIL import Image

        processor = AutoProcessor.from_pretrained(self.model_id)
        image1 = Image.new("RGB", (64, 64), color="red")
        image2 = Image.new("RGB", (64, 64), color="blue")
        image3 = Image.new("RGB", (64, 64), color="green")
        # Two conversations: the first references two images, the second references one
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image1},
                        {"type": "image", "image": image2},
                        {"type": "text", "text": "What are the differences between these two images?"},
                    ],
                }
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image3},
                        {"type": "text", "text": "What is the color of this image?"},
                    ],
                }
            ],
        ]
        prompt_token_ids = processor.apply_chat_template(
            conversation=messages, tokenize=True, add_generation_prompt=True
        )
        # `images` holds one list of images per prompt, in the same order as the prompts
        outputs = self.client.generate(prompt_token_ids, images=[[image1, image2], [image3]], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        assert all(isinstance(tok, int) for tok in prompt_ids[0])
        assert all(isinstance(tok, int) for tok in completion_ids[0])

    def test_generate_with_token_ids_mixed_images(self):
        """Test a batch where one prompt has an image and the other does not."""
        from PIL import Image

        processor = AutoProcessor.from_pretrained(self.model_id)
        image = Image.new("RGB", (64, 64), color="red")
        messages = [
            [
                {
                    "role": "user",
                    "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}],
                }
            ],
            [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": "What is 1+1?"}],
                }
            ],
        ]
        prompt_token_ids = processor.apply_chat_template(
            conversation=messages, tokenize=True, add_generation_prompt=True
        )
        # The text-only prompt gets `None` in place of an image list
        outputs = self.client.generate(prompt_token_ids, images=[[image], None], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        assert all(isinstance(tok, int) for tok in prompt_ids[0])
        assert all(isinstance(tok, int) for tok in prompt_ids[1])
        assert all(isinstance(tok, int) for tok in completion_ids[0])
        assert all(isinstance(tok, int) for tok in completion_ids[1])

    @classmethod
    def teardown_class(cls):
        # This class never initialises a communicator, so only the server has to be stopped
        kill_process(cls.server_process)
|
|
|