File size: 1,469 Bytes
01dbf80
3db4312
 
 
 
 
01dbf80
3db4312
 
 
 
 
 
 
01dbf80
 
 
 
3db4312
01dbf80
 
3db4312
 
 
01dbf80
3db4312
 
 
 
 
 
 
 
 
 
01dbf80
3db4312
 
 
 
 
 
 
01dbf80
3db4312
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import atexit
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod, remove
from os.path import abspath, exists
from pathlib import Path
import torch

from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest
from pipeline import load_pipeline, infer
# Absolute filesystem path of the socket that main() listens on:
# "inferences.sock" in the parent of the directory containing this file.
SOCKET = abspath(Path(__file__).parent.parent.joinpath("inferences.sock"))


def at_exit():
    """Release PyTorch's cached CUDA memory during interpreter shutdown.

    ``main`` registers this with ``atexit``. ``torch.cuda.empty_cache`` hands
    unused blocks held by PyTorch's caching allocator back to the driver.
    """
    cuda_runtime = torch.cuda
    cuda_runtime.empty_cache()


def main():
    """Load the pipeline and serve image-generation requests over a local socket.

    Binds a ``multiprocessing.connection.Listener`` at ``SOCKET`` and accepts a
    single client connection. It then loops over that connection:

    - Each received message is decoded as UTF-8 JSON into a
      ``TextToImageRequest``.
    - The request is passed to ``infer`` together with a CUDA generator
      seeded from ``request.seed``.
    - The resulting image is sent back as JPEG-encoded bytes.

    Returns when the client closes its end of the connection (``recv_bytes``
    raises ``EOFError``).
    """
    from contextlib import suppress

    atexit.register(at_exit)

    print("Loading pipeline")
    pipeline = load_pipeline()

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # A socket file left over from a previous run would make the bind fail.
    # Removing it EAFP-style avoids the exists()/remove() race and also clears
    # a dangling symlink, which exists() would report as absent.
    with suppress(FileNotFoundError):
        remove(SOCKET)

    with Listener(SOCKET) as listener:
        # NOTE(review): 0o777 lets any local user connect and submit work.
        # Kept as-is for compatibility with clients that may run as another
        # user; consider tightening (e.g. 0o770 with a shared group).
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")
            # A single CUDA generator is reused and re-seeded for every
            # request, so each request's random stream is set by its seed.
            generator = torch.Generator("cuda")
            while True:
                # Only the receive signals a client disconnect, so the try
                # body is limited to that single call.
                try:
                    payload = connection.recv_bytes()
                except EOFError:
                    print("Inference socket exiting")

                    return

                request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))
                # NOTE(review): assumes request.seed is always an int accepted
                # by manual_seed; confirm against the TextToImageRequest schema.
                image = infer(request, pipeline, generator.manual_seed(request.seed))

                with BytesIO() as data:
                    image.save(data, format=JpegImageFile.format)
                    connection.send_bytes(data.getvalue())


# Start the inference server only when executed as a script, not on import.
if __name__ == "__main__":
    main()