| | """ |
| | A controller manages distributed workers. |
| | It sends worker addresses to clients. |
| | """ |
| | import argparse |
| | import asyncio |
| | import dataclasses |
| | from enum import Enum, auto |
| | import json |
| | import logging |
| | import time |
| | from typing import List, Union |
| | import threading |
| |
|
| | from fastapi import FastAPI, Request |
| | from fastapi.responses import StreamingResponse |
| | import numpy as np |
| | import requests |
| | import uvicorn |
| |
|
| | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION |
| | from llava.utils import build_logger, server_error_msg |
| |
|
| |
|
# Module-wide logger used by the controller class, the heart-beat loop and the
# script entry point below.
logger = build_logger("controller", "controller.log")
| |
|
| |
|
class DispatchMethod(Enum):
    """Strategy used to choose a worker for an incoming request."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Convert a command-line style name into a ``DispatchMethod``.

        Args:
            name: Either ``"lottery"`` or ``"shortest_queue"``.

        Returns:
            The matching enum member.

        Raises:
            ValueError: If ``name`` is not one of the recognised names. The
                message includes the rejected value.
        """
        if name == "lottery":
            return cls.LOTTERY
        if name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        # Report the offending value so a misconfiguration can be diagnosed.
        raise ValueError(f"Invalid dispatch method: {name}")
| |
|
| |
|
@dataclasses.dataclass
class WorkerInfo:
    """Bookkeeping record the controller keeps for one registered worker."""

    # Names of the models this worker reports it can serve.
    model_names: List[str]
    # Relative speed reported by the worker; used as a dispatch weight.
    speed: int
    # Most recently known number of queued requests on the worker.
    queue_length: int
    # Whether the worker is evicted when its heart beats stop arriving.
    check_heart_beat: bool
    # Timestamp of the latest registration or heart beat. It is always set
    # from time.time(), so it is a float, not a string.
    last_heart_beat: float
| |
|
| |
|
def heart_beat_controller(controller):
    """Periodically ask ``controller`` to evict workers with expired heart beats.

    Loops forever, sleeping ``CONTROLLER_HEART_BEAT_EXPIRATION`` seconds between
    sweeps. ``Controller.__init__`` runs it as the target of a daemon thread.

    Args:
        controller: Object exposing ``remove_stable_workers_by_expiration()``.
    """
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        try:
            controller.remove_stable_workers_by_expiration()
        except Exception as e:
            # An uncaught exception would end this thread silently and no
            # further expiry sweeps would ever run; log it and keep going.
            logger.error(f"Heart beat expiration sweep failed: {e}")
| |
|
| |
|
class Controller:
    """Registry of model workers that picks a worker for each client request.

    Workers are keyed by their name, which is also used as the base URL for
    HTTP calls to the worker. A daemon thread started in ``__init__``
    periodically removes workers whose heart beats have expired.
    """

    def __init__(self, dispatch_method: str):
        """Create an empty registry and start the heart-beat expiry thread.

        Args:
            dispatch_method: ``"lottery"`` or ``"shortest_queue"``.

        Raises:
            ValueError: If ``dispatch_method`` is not recognised.
        """
        # Maps worker name -> WorkerInfo.
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,), daemon=True)
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        """Add a worker to the registry or overwrite its existing entry.

        Args:
            worker_name: Worker identifier; also its base URL.
            check_heart_beat: Whether the worker is subject to heart-beat
                expiry.
            worker_status: Dict with ``"model_names"``, ``"speed"`` and
                ``"queue_length"``. When falsy, the status is fetched from
                the worker itself.

        Returns:
            True when the worker was registered, False when no status could
            be obtained from it.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """Query a worker's ``/worker_get_status`` endpoint.

        Returns:
            The decoded JSON body, or None when the request fails or the
            worker answers with a status code other than 200.
        """
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Drop a worker from the registry; a no-op if it is not registered."""
        # pop() rather than del: the worker may already have been removed
        # (e.g. by a registry refresh) between being selected for removal
        # and this call.
        self.worker_info.pop(worker_name, None)

    def refresh_all_workers(self):
        """Re-register every known worker, dropping those that do not respond."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            # Passing None forces a fresh status fetch from the worker.
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the de-duplicated model names served by registered workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Choose a worker that serves ``model_name``.

        With ``LOTTERY`` a worker is drawn at random, weighted by speed. With
        ``SHORTEST_QUEUE`` the worker with the smallest ``queue_length / speed``
        is chosen and its recorded queue length is incremented.

        Returns:
            The chosen worker's name, or ``""`` when no suitable worker exists.

        Raises:
            ValueError: If the configured dispatch method is unknown.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            # Also covers the "no candidates" case: the sum of an empty array is 0.
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            pt = np.random.choice(np.arange(len(worker_names)),
                                  p=worker_speeds)
            return worker_names[pt]
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name not in w_info.model_names:
                    continue
                # A non-positive speed would divide by zero (or give a
                # meaningless negative load). The lottery branch likewise
                # never selects a zero-speed worker.
                if w_info.speed <= 0:
                    continue
                worker_names.append(w_name)
                worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Count the request about to be dispatched so later picks see it;
            # the next heart beat overwrites queue_length with the real value.
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat and the worker's current queue length.

        Returns:
            True when the worker is registered, False otherwise.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        """Remove heart-beat-checked workers whose last beat is too old.

        A worker is removed when ``check_heart_beat`` is set and its last
        heart beat is older than ``CONTROLLER_HEART_BEAT_EXPIRATION`` seconds.
        """
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Iterate over a snapshot: this method is called from the heart-beat
        # thread while other code may register or remove workers.
        to_delete = [
            worker_name
            for worker_name, w_info in list(self.worker_info.items())
            if w_info.check_heart_beat and w_info.last_heart_beat < expire
        ]

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generation request to a suitable worker.

        Args:
            params: Request parameters; ``params["model"]`` selects the worker
                and the whole dict is forwarded as the JSON body.

        Yields:
            NUL-terminated byte chunks relayed from the worker. If no worker
            is available a single error payload with ``error_code`` 2 is
            yielded; if the request to the worker fails, one with
            ``error_code`` 3.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            # Stop here: without an address the request below could only
            # fail and emit a second, misleading error.
            return

        try:
            # The context manager releases the streamed connection even if
            # the consumer stops reading early.
            with requests.post(worker_addr + "/worker_generate_stream",
                               json=params, stream=True, timeout=5) as response:
                for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                    if chunk:
                        yield chunk + b"\0"
        except requests.exceptions.RequestException:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    def worker_api_get_status(self):
        """Aggregate live status across all registered workers.

        Each worker is queried directly; workers that do not return a status
        are skipped.

        Returns:
            Dict with the union of ``model_names`` and the summed ``speed``
            and ``queue_length`` of the responding workers.
        """
        model_names = set()
        speed = 0
        queue_length = 0

        # Snapshot the names: each status call can block for seconds, during
        # which the heart-beat thread may remove workers.
        for w_name in list(self.worker_info):
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }
| |
|
| |
|
# FastAPI application; the route handlers below are registered on it and it is
# served by uvicorn from the script entry point.
app = FastAPI()
| |
|
| |
|
@app.post("/register_worker")
async def register_worker(request: Request):
    """Register the worker described by the JSON request body.

    The body must contain ``worker_name`` and ``check_heart_beat`` and may
    contain ``worker_status``.
    """
    payload = await request.json()
    name = payload["worker_name"]
    wants_heart_beat = payload["check_heart_beat"]
    status = payload.get("worker_status", None)
    controller.register_worker(name, wants_heart_beat, status)
| |
|
| |
|
@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Re-register every known worker, dropping those that do not respond."""
    # The controller method returns nothing, so there is no result to keep.
    controller.refresh_all_workers()
| |
|
| |
|
@app.post("/list_models")
async def list_models():
    """Return the model names served by the registered workers."""
    return {"models": controller.list_models()}
| |
|
| |
|
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Return the address of a worker serving the ``model`` named in the body."""
    payload = await request.json()
    return {"address": controller.get_worker_address(payload["model"])}
| |
|
| |
|
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Record a worker heart beat and report whether the worker is known."""
    payload = await request.json()
    worker_name = payload["worker_name"]
    queue_length = payload["queue_length"]
    return {"exist": controller.receive_heart_beat(worker_name, queue_length)}
| |
|
| |
|
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream a generation response relayed through the controller."""
    body = await request.json()
    chunks = controller.worker_api_generate_stream(body)
    return StreamingResponse(chunks)
| |
|
| |
|
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Return the status aggregated over all registered workers."""
    # The request body is not read; the parameter is kept for the signature.
    status = controller.worker_api_get_status()
    return status
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the global controller
    # used by the route handlers, then serve the app.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--host", type=str, default="localhost")
    arg_parser.add_argument("--port", type=int, default=21001)
    arg_parser.add_argument(
        "--dispatch-method",
        type=str,
        choices=["lottery", "shortest_queue"],
        default="shortest_queue",
    )
    args = arg_parser.parse_args()
    logger.info(f"args: {args}")

    # Module-level name: the route handlers look it up as a global.
    controller = Controller(args.dispatch_method)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
| |
|