import os
import subprocess
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod
from os.path import abspath
from pathlib import Path

from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest

from pipeline import load_pipeline, infer

SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")


def run_command(command):
    """Run *command* through the shell and fail loudly on a non-zero exit.

    Parameters:
        command: Shell command line to execute. It is passed with
            ``shell=True``, so only trusted, hard-coded strings may be
            used here — never untrusted input.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
            (RuntimeError subclasses Exception, so existing broad
            handlers keep working while the type is more specific
            than the bare ``Exception`` previously raised.)
    """
    process = subprocess.run(command, shell=True)
    if process.returncode != 0:
        raise RuntimeError(f"Command failed: {command}")


def install_dependencies():
    """Install the pinned runtime dependencies, stopping at the first failure."""
    commands = (
        # torch/torchvision come from the CUDA 12.1 PyTorch wheel index.
        "pip install --index-url https://download.pytorch.org/whl/cu121 torch==2.2.2 torchvision==0.17.2",
        # Remaining pinned packages from the default index.
        "pip install accelerate==0.31.0 numpy==1.26.4 xformers==0.0.25.post1 triton==2.2.0 transformers==4.41.2",
        # stable_fast ships as a prebuilt wheel on GitHub releases:
        # download it, then install the local .whl.
        "wget https://github.com/chengzeyi/stable-fast/releases/download/v1.0.5/stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl",
        "pip install stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl",
    )
    for command in commands:
        run_command(command)


def main():
    """Entry point: install dependencies, load the inference pipeline, and
    serve requests over a Unix-domain socket until the client disconnects.

    Protocol (as shown by the code below): each incoming message is a
    UTF-8 JSON payload parsed as a TextToImageRequest; the reply is the
    generated image serialized as JPEG bytes.
    """
    # Install dependencies before running the pipeline
    print("Installing dependencies...")
    install_dependencies()

    print("Loading pipeline")
    pipeline = load_pipeline()
    print("Pipeline loaded")

    # Bug fix: remove a stale socket file left behind by a previous run,
    # otherwise Listener() fails with "Address already in use".
    if os.path.exists(SOCKET):
        os.remove(SOCKET)

    print(f"Creating socket at '{SOCKET}'")
    with Listener(SOCKET) as listener:
        # World-accessible so a client running as another user can connect.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")

            while True:
                try:
                    request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
                except EOFError:
                    # Peer closed the connection — normal shutdown path.
                    print("Inference socket exiting")

                    return

                image = infer(request, pipeline)

                # Serialize the image to an in-memory JPEG buffer and send
                # the raw bytes back over the connection.
                data = BytesIO()
                image.save(data, format=JpegImageFile.format)

                packet = data.getvalue()

                connection.send_bytes(packet)


# Script entry point: run the inference server only when executed directly.
if __name__ == "__main__":
    main()