# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import logging
import socket
import time
from typing import Optional
from urllib.parse import urlparse

import torch
from torch import nn

from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available


if is_requests_available():
    import requests
    from requests import ConnectionError


if is_vllm_available():
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

if is_vllm_ascend_available():
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator


logger = logging.getLogger(__name__)
class VLLMClient:
    """
    A client class to interact with a vLLM server.

    This class provides methods to generate completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        base_url (`str` or `None`, *optional*, defaults to `None`):
            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
            ignored.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            IP address of the vLLM server. Ignored if `base_url` is provided.
        server_port (`int`, *optional*, defaults to `8000`):
            Port number of the vLLM server. Ignored if `base_url` is provided.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
            timeout, a `ConnectionError` is raised.

    Examples:
        Run the vLLM server with the model `Qwen/Qwen2.5-7B`:

        ```
        $ trl vllm-serve --model Qwen/Qwen2.5-7B
        ...
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
        ```

        Use the client to generate completions and update model weights:

        ```python
        >>> from trl.extras.vllm_client import VLLMClient

        >>> client = VLLMClient()
        >>> client.generate(["Hello, AI!", "Tell me a joke"])
        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]

        >>> from transformers import AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
        >>> client.init_communicator()
        >>> client.update_model_params(model)
        ```

        There are several ways to initialize the client:

        ```python
        VLLMClient(base_url="http://localhost:8000")
        VLLMClient(base_url="http://192.168.1.100:8000")
        VLLMClient(host="localhost", server_port=8000)
        VLLMClient(host="192.168.1.100", server_port=8000)
        ```
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        host: str = "0.0.0.0",
        server_port: int = 8000,
        group_port: int = 51216,
        connection_timeout: float = 0.0,
    ):
        if not is_requests_available():
            raise ImportError("requests is not installed. Please install it with `pip install requests`.")
        if not is_vllm_available():
            raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

        self.session = requests.Session()

        if base_url is not None:
            # Parse the base_url to extract host and port
            parsed_url = urlparse(base_url)
            self.host = socket.gethostbyname(parsed_url.hostname)
            scheme = parsed_url.scheme or "http"
            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
        else:
            self.host = host
            self.server_port = server_port
            self.base_url = f"http://{self.host}:{self.server_port}"
        self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
        Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds. With the default `0.0`, the first failed attempt raises
                immediately.
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
        """
        url = f"{self.base_url}/health/"
        start_time = time.time()  # Record the start time

        while True:
            try:
                response = requests.get(url)
            except requests.exceptions.RequestException as exc:
                # Check if the total timeout duration has passed
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(
                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
                        "sure the server is running by running `trl vllm-serve`."
                    ) from exc
            else:
                if response.status_code == 200:
                    if "X-Forwarded-For" in response.headers:
                        self.host = response.headers["X-Forwarded-For"]
                    logger.info("Server is up!")
                    return None

            # Retry logic: wait before trying again
            logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...")
            time.sleep(retry_interval)

    def generate(
        self,
        prompts: list[str],
        n: int = 1,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        max_tokens: int = 16,
        guided_decoding_regex: Optional[str] = None,
    ) -> list[list[int]]:
        """
        Generates model completions for the provided prompts.

        Args:
            prompts (`list[str]`):
                List of text prompts for which the model will generate completions.
            n (`int`, *optional*, defaults to `1`):
                Number of completions to generate for each prompt.
            repetition_penalty (`float`, *optional*, defaults to `1.0`):
                Parameter for repetition penalty. `1.0` means no penalty.
            temperature (`float`, *optional*, defaults to `1.0`):
                Temperature parameter for sampling. Higher values increase diversity.
            top_p (`float`, *optional*, defaults to `1.0`):
                Top-p sampling parameter. `1.0` means no truncation.
            top_k (`int`, *optional*, defaults to `-1`):
                Top-k sampling parameter. `-1` means no truncation.
            min_p (`float`, *optional*, defaults to `0.0`):
                Minimum probability for sampling.
            max_tokens (`int`, *optional*, defaults to `16`):
                Maximum number of tokens to generate for each prompt.
            guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
                Regular expression to guide the decoding process.

        Returns:
            `list[list[int]]`:
                List of lists of token IDs representing the model-generated completions for each prompt.
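
        Example (illustrative sketch; assumes a server launched with `trl vllm-serve --model Qwen/Qwen2.5-7B` is
        reachable, as in the class-level example):

        ```python
        >>> client = VLLMClient()
        >>> completion_ids = client.generate(["The capital of France is"], temperature=0.0, max_tokens=8)
        >>> # Completions are returned as token IDs; decode them with the served model's tokenizer.
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
        >>> texts = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        ```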
| """ | |
| url = f"{self.base_url}/generate/" | |
| response = self.session.post( | |
| url, | |
| json={ | |
| "prompts": prompts, | |
| "n": n, | |
| "repetition_penalty": repetition_penalty, | |
| "temperature": temperature, | |
| "top_p": top_p, | |
| "top_k": top_k, | |
| "min_p": min_p, | |
| "max_tokens": max_tokens, | |
| "guided_decoding_regex": guided_decoding_regex, | |
| }, | |
| ) | |
| if response.status_code == 200: | |
| return response.json()["completion_ids"] | |
| else: | |
| raise Exception(f"Request failed: {response.status_code}, {response.text}") | |

    def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
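
        Example (illustrative sketch of the typical call order; assumes a running server and a loaded `transformers`
        model, as in the class-level example):

        ```python
        >>> client.init_communicator()  # one-time setup of the weight update group
        >>> client.update_model_params(model)  # broadcast the current weights to the server
        >>> client.close_communicator()  # optional; also registered via atexit
        ```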
| """ | |
| # Get the world size from the server | |
| url = f"{self.base_url}/get_world_size/" | |
| response = requests.get(url) | |
| if response.status_code == 200: | |
| vllm_world_size = response.json()["world_size"] | |
| else: | |
| raise Exception(f"Request failed: {response.status_code}, {response.text}") | |
| world_size = vllm_world_size + 1 # add the client to the world | |
| self.rank = vllm_world_size # the client's rank is the last process | |
| # Initialize weight update group | |
| url = f"{self.base_url}/init_communicator/" | |
| # In the server side, the host is set to 0.0.0.0 | |
| response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size}) | |
| if response.status_code != 200: | |
| raise Exception(f"Request failed: {response.status_code}, {response.text}") | |
| # Brief delay to allow server initialization. While not strictly required (client socket will retry on | |
| # connection failure), this prevents log warnings like: | |
| # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 | |
| time.sleep(0.1) | |
| # Set up the communication group for weight broadcasting | |
| pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) | |
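        # NOTE: `device=0` assumes the client process drives GPU 0; adjust if the trainer owns a different device.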
        self.pynccl_comm = PyNcclCommunicator(pg, device=0)

        # When the client object is deleted, close the weight update group
        atexit.register(self.close_communicator)

    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
            name (`str`):
                Name of the parameter to update, as it appears in the model's `state_dict`.
            weights (`torch.Tensor`):
                Tensor containing the updated weights.
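
        Example (illustrative; assumes `init_communicator` has already been called and `model` is the training copy
        of the served model):

        ```python
        >>> name, param = next(iter(model.named_parameters()))
        >>> client.update_named_param(name, param.data)
        ```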
| """ | |
| dtype, shape = str(weights.dtype), tuple(weights.shape) | |
| url = f"{self.base_url}/update_named_param/" | |
| response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape}) | |
| if response.status_code != 200: | |
| raise Exception(f"Request failed: {response.status_code}, {response.text}") | |
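        # The POST above only carries metadata; the server is expected to allocate a buffer of the given dtype/shape
        # and join the collective below to receive the actual tensor data.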
        # Broadcast the weights to the other processes
        self.pynccl_comm.broadcast(weights, src=self.rank)
        self.pynccl_comm.group.barrier()

    def update_model_params(self, model: nn.Module):
        """
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
        """
        for name, param in model.named_parameters():
            # Update each parameter individually
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
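
        Calling this after a weight update avoids reusing KV-cache entries that were computed with stale weights.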
| """ | |
| url = f"{self.base_url}/reset_prefix_cache/" | |
| response = self.session.post(url) | |
| if response.status_code != 200: | |
| raise Exception(f"Request failed: {response.status_code}, {response.text}") | |

    def close_communicator(self):
        """
        Closes the weight update group and cleans up the communication group.
        """
        url = f"{self.base_url}/close_communicator/"

        try:
            response = self.session.post(url)
        except ConnectionError:
            # The server might be already down, so we don't need to close the communicator
            pass
        else:
            if response.status_code != 200:
                raise Exception(f"Request failed: {response.status_code}, {response.text}")


# Example usage
if __name__ == "__main__":
    client = VLLMClient()
    client.init_communicator()

    # Generate completions (`generate` takes sampling options as keyword arguments, not a `SamplingParams` object)
    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
    print("Responses:", responses)  # noqa

    # Update model weights
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
    client.update_model_params(model)