File size: 4,186 Bytes
c446951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import subprocess

import typer

import docker

docker_client = docker.from_env()


def ask_user_to_kill_container(c):
    """Ask the user to confirm deletion of an already-running inference container.

    Reads the container's name and its PORT environment variable from the
    Docker attrs, then shows an interactive confirmation prompt.

    Parameters:
        c: a docker-py Container object.

    Returns:
        bool: True when the user confirms deletion, False otherwise.
    """
    name = c.attrs.get("Name", "")
    # Docker's Config.Env is a list of "KEY=VALUE" strings, not a dict —
    # default to an empty list accordingly (the original defaulted to {}).
    env_vars = c.attrs.get("Config", {}).get("Env", [])
    port = 9001  # default inference-server port when no PORT variable is set
    for var in env_vars:
        if var.startswith("PORT="):
            # Split on the first "=" only, so a value containing "=" survives.
            port = var.split("=", 1)[1]
    should_delete = typer.confirm(
        f" An inference server is already running in container {name} on port {port}. Are you sure you want to delete it?"
    )
    return should_delete


def is_inference_server_container(container):
    """Return True when *container* was built from a Roboflow inference-server image.

    Checks every tag on the container's image for the
    ``roboflow/roboflow-inference-server`` prefix.
    """
    return any(
        tag.startswith("roboflow/roboflow-inference-server")
        for tag in container.image.tags
    )


def handle_existing_containers(containers):
    """Offer to kill each running inference-server container in *containers*.

    For every running inference-server container the user is asked for
    confirmation; confirmed containers are killed.

    Bug fix: the original tracked a single boolean and flipped it to False on
    ANY kill, so with several containers one confirmed kill hid the ones the
    user declined to remove. We now report True if any inference container
    survives the pass.

    Parameters:
        containers: iterable of docker-py Container objects.

    Returns:
        bool: True when at least one inference-server container is still
        present (i.e. the caller should not start a new server), else False.
    """
    has_existing_containers = False
    for c in containers:
        if not is_inference_server_container(c):
            continue
        is_running = c.attrs.get("State", {}).get("Status", "").lower() == "running"
        if is_running and ask_user_to_kill_container(c):
            c.kill()
        else:
            # Either not running (nothing to kill) or the user declined —
            # this container still blocks starting a fresh server.
            has_existing_containers = True
    return has_existing_containers


def find_existing_containers():
    """Return every currently-running Roboflow inference-server container.

    Scans the Docker daemon's container list and keeps only containers that
    (a) come from an inference-server image and (b) report a "running" state.
    """
    return [
        container
        for container in docker_client.containers.list()
        if is_inference_server_container(container)
        and container.attrs.get("State", {}).get("Status", "").lower() == "running"
    ]


def get_image():
    """Select the inference-server Docker image to run.

    Probes for a GPU by invoking ``nvidia-smi``; if the tool exists and exits
    successfully, the GPU image tag is returned, otherwise the CPU one.

    Returns:
        str: the Docker image tag to pull/run.
    """
    try:
        # Argv list (no shell); raises OSError if nvidia-smi is absent and
        # CalledProcessError if it exits non-zero (e.g. no driver loaded).
        subprocess.check_output(["nvidia-smi"])
        print("GPU detected. Using a GPU image.")
        return "roboflow/roboflow-inference-server-gpu:latest"
    except (OSError, subprocess.CalledProcessError):
        # Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit.
        print("No GPU detected. Using a CPU image.")
        return "roboflow/roboflow-inference-server-cpu:latest"


def start_inference_container(
    api_key,
    image=None,
    port=9001,
    labels=None,
    project="roboflow-platform",
    metrics_enabled=True,
    device_id=None,
    num_workers=1,
):
    """Start a Roboflow inference-server container in the background.

    Refuses to start when an existing inference container survives the
    interactive cleanup pass. When *image* is None, a GPU or CPU image is
    auto-selected via get_image().

    Parameters:
        api_key: Roboflow API key passed to the server via API_KEY.
        image: Docker image tag; auto-detected when None.
        port: host port mapped to the container's 9001 and exported as PORT.
        labels: optional Docker labels dict for the container.
        project: value for the PROJECT environment variable.
        metrics_enabled: value for METRICS_ENABLED.
        device_id: value for DEVICE_ID.
        num_workers: value for NUM_WORKERS.

    Returns:
        None. Side effect: a detached container is started via docker_client.
    """
    containers = find_existing_containers()
    if containers:
        if handle_existing_containers(containers):
            print("Please kill the existing containers and try again.")
            return

    if image is None:
        image = get_image()

    device_requests = None
    privileged = False
    if "gpu" in image:
        privileged = True
        # Bug fix: a stray trailing comma previously wrapped this list in a
        # one-element tuple, which docker-py does not accept for
        # device_requests (it expects a list of DeviceRequest objects).
        device_requests = [
            docker.types.DeviceRequest(device_ids=["all"], capabilities=[["gpu"]])
        ]

    print("Starting inference server container...")
    # NOTE(review): the container port is hard-coded to "9001" while PORT=<port>
    # is exported to the server — if the server binds PORT internally, a
    # non-default *port* may leave the mapping pointing at the wrong
    # container port. TODO confirm against the server image's behavior.
    docker_client.containers.run(
        image=image,
        privileged=privileged,
        detach=True,
        labels=labels,
        ports={"9001": port},
        device_requests=device_requests,
        environment=[
            "HOST=0.0.0.0",
            f"PORT={port}",
            f"PROJECT={project}",
            f"METRICS_ENABLED={metrics_enabled}",
            f"DEVICE_ID={device_id}",
            f"API_KEY={api_key}",
            f"NUM_WORKERS={num_workers}",
        ],
    )


def check_inference_server_status():
    """Print status details of the first running inference-server container.

    Reports name, creation time, exposed port, state, and image for the first
    running container found (matching the original's first-match-only
    behavior), or a fallback message when none is running.

    Returns:
        None. Output is written to stdout.
    """
    containers = find_existing_containers()
    for c in containers:
        container_name = c.attrs.get("Name", "")
        created = c.attrs.get("Created", "")
        # Bug fix: the original did .get("Config") with no default (crashing
        # on a missing Config) and indexed [0] blindly (IndexError when the
        # image exposes no ports).
        exposed_ports = list(c.attrs.get("Config", {}).get("ExposedPorts", {}).keys())
        exposed_port = exposed_ports[0] if exposed_ports else "unknown"
        status = c.attrs.get("State", {}).get("Status", "unknown")
        image = c.attrs.get("Image", "")
        container_status_message = """
Container Name: {container_name}
Created: {created}
Exposed Port: {exposed_port}
Status: {status}
Image: {image}
            """
        print(
            container_status_message.format(
                container_name=container_name,
                created=created,
                exposed_port=exposed_port,
                status=status,
                image=image,
            )
        )
        return
    print("No inference server container running.")


# Script entry point: launch an inference server using a placeholder API key.
# NOTE(review): "my_api_key" is a stand-in — replace with a real Roboflow key
# (or wire this to CLI arguments) before running for real.
if __name__ == "__main__":
    start_inference_container("my_api_key")