"""Model Server Launcher - Supports both LlamaCPP and vLLM backends"""
import os
import subprocess
import sys
from dataclasses import asdict
from trl import TrlParser
from linalg_zero.config.data import DistillationConfig, LlamaCppServerConfig, VllmServerConfig
from linalg_zero.shared.utils import get_logger, setup_logging
def launch_llamacpp(config: LlamaCppServerConfig) -> None:
    """Launch LlamaCPP server"""
    print("Starting llama.cpp server...")
    # Translate each dataclass field into CLI flags for llama_cpp.server:
    # None fields are skipped, True fields become bare flags, everything
    # else becomes a `--key value` pair.
    command = ["uv", "run", "python3", "-m", "llama_cpp.server"]
    for key, value in asdict(config).items():
        if value is None:
            continue
        if value is True:
            command.append(f"--{key}")
        else:
            command.extend([f"--{key}", str(value)])
    print(f"Command: {' '.join(command)}")
    # List form (shell=False) avoids shell-injection; check=True surfaces
    # a non-zero server exit as CalledProcessError.
    _ = subprocess.run(command, check=True)  # noqa: S603
def launch_vllm(config: VllmServerConfig) -> None:
    """Launch vLLM server"""
    print("Starting vLLM server...")
    config_dict = asdict(config)
    args = ["uv", "run", "vllm", "serve"]
    for k, v in config_dict.items():
        if k == "model":
            # `vllm serve` takes the model as a positional argument,
            # not as a --model flag.
            args.extend([str(v)])
        elif v is True:
            # Boolean True maps to a bare flag with no value.
            args.extend([f"--{k}"])
        elif v is None:
            # Unset fields are omitted so vLLM applies its own defaults.
            continue
        else:
            args.extend([f"--{k}", str(v)])
    print(f"Command: {' '.join(args)}")
    # check=True (matching launch_llamacpp) so a failed server launch
    # propagates as CalledProcessError instead of silently exiting 0.
    _ = subprocess.run(args, check=True)  # noqa: S603
def _log_config(logger, config) -> None:
    """Log every field/value pair of a server config dataclass."""
    for field_name in config.__dataclass_fields__:
        value = getattr(config, field_name)
        logger.info(f"  {field_name}: {value}")


def main() -> None:
    """Main function.

    Selects the inference backend from the INFERENCE_BACKEND environment
    variable ("llamacpp" or "vllm"), parses the matching server config via
    TrlParser, and launches the corresponding server process. Exits with
    status 1 on an unknown/missing backend.
    """
    setup_logging(file_suffix="distillation_server.log")
    logger = get_logger(__name__)

    # Guard the argv access: the config path is conventionally the second
    # CLI argument (`script --config <path>`), but may be absent.
    if len(sys.argv) > 2:
        logger.info(f"Using the configuration file stored at: {os.path.abspath(sys.argv[2])}")

    # .get() instead of [] so a missing variable falls through to the
    # explicit error branch below rather than raising a bare KeyError.
    backend = os.environ.get("INFERENCE_BACKEND")

    # Determine which server config to use based on backend
    if backend == "llamacpp":
        parser = TrlParser(dataclass_types=[DistillationConfig, LlamaCppServerConfig])
        llama_cpp_config: LlamaCppServerConfig = parser.parse_args_and_config()[1]
        logger.info("LlamaCPP Server configuration:")
        _log_config(logger, llama_cpp_config)
        launch_llamacpp(llama_cpp_config)
    elif backend == "vllm":
        parser = TrlParser(dataclass_types=[DistillationConfig, VllmServerConfig])
        vllm_config: VllmServerConfig = parser.parse_args_and_config()[1]
        logger.info("vLLM Server configuration:")
        _log_config(logger, vllm_config)
        launch_vllm(vllm_config)
    else:
        logger.error(f"Unknown backend: {backend}")
        # Exit non-zero so callers/orchestrators see the misconfiguration
        # (previously this fell through and exited 0).
        sys.exit(1)


if __name__ == "__main__":
    main()