File size: 2,874 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Model Server Launcher - Supports both LlamaCPP and vLLM backends"""

import logging
import os
import subprocess
import sys
from dataclasses import asdict

from trl import TrlParser

from linalg_zero.config.data import DistillationConfig, LlamaCppServerConfig, VllmServerConfig
from linalg_zero.shared.utils import get_logger, setup_logging


def launch_llamacpp(config: LlamaCppServerConfig) -> None:
    """Spawn a llama.cpp OpenAI-compatible server as a blocking subprocess.

    Each dataclass field becomes a CLI argument for ``llama_cpp.server``:
    a ``True`` value is emitted as a bare ``--flag``, a ``None`` value is
    omitted entirely, and anything else is passed as ``--key value``.

    Raises:
        subprocess.CalledProcessError: if the server process exits with a
            non-zero status (``check=True``).
    """
    print("Starting llama.cpp server...")

    command = ["uv", "run", "python3", "-m", "llama_cpp.server"]
    for key, value in asdict(config).items():
        if value is None:
            continue  # unset fields are simply not forwarded
        command.append(f"--{key}")
        if value is not True:
            command.append(str(value))

    print(f"Command: {' '.join(command)}")
    _ = subprocess.run(command, check=True)  # noqa: S603


def launch_vllm(config: VllmServerConfig) -> None:
    """Spawn a vLLM OpenAI-compatible server as a blocking subprocess.

    The ``model`` field is passed positionally (``vllm serve <model>``);
    every other dataclass field becomes a CLI argument: ``True`` is emitted
    as a bare ``--flag``, ``None`` fields are omitted, and anything else is
    passed as ``--key value``.

    Raises:
        subprocess.CalledProcessError: if the server process exits with a
            non-zero status.
    """
    print("Starting vLLM server...")

    config_dict = asdict(config)
    args = ["uv", "run", "vllm", "serve"]
    for k, v in config_dict.items():
        if k == "model":
            # `vllm serve` takes the model as a positional argument, not a flag.
            args.append(str(v))
        elif v is True:
            args.append(f"--{k}")
        elif v is None:
            continue
        else:
            args.extend([f"--{k}", str(v)])

    print(f"Command: {' '.join(args)}")
    # Fix: previously ran without check=True, so a failed launch returned
    # silently; now a non-zero exit raises, consistent with launch_llamacpp.
    _ = subprocess.run(args, check=True)  # noqa: S603


def _log_config(logger: logging.Logger, title: str, config: object) -> None:
    """Log *title* followed by every dataclass field of *config*."""
    logger.info(title)
    for field_name in config.__dataclass_fields__:  # type: ignore[attr-defined]
        logger.info(f"  {field_name}: {getattr(config, field_name)}")


def main() -> None:
    """Entry point: launch the server backend selected by INFERENCE_BACKEND.

    Reads the backend name from the ``INFERENCE_BACKEND`` environment
    variable ("llamacpp" or "vllm"), parses the matching server config via
    TrlParser (which consumes the config file path from the CLI), and
    launches the corresponding server. Exits with status 1 when the backend
    is missing or unknown.
    """
    setup_logging(file_suffix="distillation_server.log")
    logger = get_logger(__name__)
    # Fix: guard the argv access — previously sys.argv[2] raised an opaque
    # IndexError when the config path was not supplied on the command line.
    if len(sys.argv) > 2:
        logger.info(f"Using the configuration file stored at: {os.path.abspath(sys.argv[2])}")

    # Fix: os.environ["INFERENCE_BACKEND"] raised a bare KeyError when the
    # variable was unset; report it explicitly and exit non-zero instead.
    backend = os.environ.get("INFERENCE_BACKEND")
    if backend is None:
        logger.error("INFERENCE_BACKEND environment variable is not set")
        sys.exit(1)

    if backend == "llamacpp":
        # Parse with the LlamaCPP-specific server config.
        parser = TrlParser(dataclass_types=[DistillationConfig, LlamaCppServerConfig])
        llama_cpp_config: LlamaCppServerConfig = parser.parse_args_and_config()[1]
        _log_config(logger, "LlamaCPP Server configuration:", llama_cpp_config)
        launch_llamacpp(llama_cpp_config)

    elif backend == "vllm":
        # Parse with the vLLM-specific server config.
        parser = TrlParser(dataclass_types=[DistillationConfig, VllmServerConfig])
        vllm_config: VllmServerConfig = parser.parse_args_and_config()[1]
        _log_config(logger, "vLLM Server configuration:", vllm_config)
        launch_vllm(vllm_config)

    else:
        logger.error(f"Unknown backend: {backend}")
        # Fix: previously logged the error but returned exit code 0, so
        # wrapper scripts could not detect the misconfiguration.
        sys.exit(1)

if __name__ == "__main__":
    main()