| | import os |
| | import subprocess |
| | from io import BytesIO |
| | from multiprocessing.connection import Listener |
| | from os import chmod |
| | from os.path import abspath |
| | from pathlib import Path |
| |
|
| | from PIL.JpegImagePlugin import JpegImageFile |
| | from pipelines.models import TextToImageRequest |
| |
|
| | from pipeline import load_pipeline, infer |
| |
|
# Absolute filesystem path of the socket that main() listens on for inference
# requests; the file is named "inferences.sock" and sits in the parent of the
# directory containing this script.
SOCKET = os.path.abspath(Path(__file__).parent.parent.joinpath("inferences.sock"))
| |
|
| |
|
def run_command(command):
    """Run a shell command line and raise if it exits unsuccessfully.

    Args:
        command: The command line to execute; it is handed to the shell
            (``shell=True``), so it may use shell syntax.

    Raises:
        Exception: If the command's exit status is non-zero.
    """
    completed = subprocess.run(command, shell=True)
    if completed.returncode == 0:
        return
    raise Exception(f"Command failed: {command}")
| |
|
| |
|
def install_dependencies():
    """Install the pinned runtime dependencies of the inference pipeline.

    Every step goes through ``run_command``, so the first failing install
    raises ``Exception`` and the remaining steps are skipped.

    pip is invoked as ``<current interpreter> -m pip`` rather than a bare
    ``pip``: the ``pip`` found on PATH may belong to a different Python, in
    which case the packages would be installed somewhere this process cannot
    import them from.
    """
    import shlex
    import sys

    # Quote the interpreter path in case it contains spaces or shell metacharacters.
    pip_install = f"{shlex.quote(sys.executable)} -m pip install"

    # torch / torchvision CUDA 12.1 ("cu121") builds come from PyTorch's own index.
    run_command(
        f"{pip_install} --index-url https://download.pytorch.org/whl/cu121 "
        "torch==2.2.2 torchvision==0.17.2"
    )

    run_command(
        f"{pip_install} accelerate==0.31.0 numpy==1.26.4 "
        "xformers==0.0.25.post1 triton==2.2.0 transformers==4.41.2"
    )

    # stable-fast ships as a prebuilt wheel on GitHub releases. pip can install
    # it directly from the URL, which avoids depending on wget and leaving the
    # .whl (plus wget's numbered duplicates on re-runs) in the working directory.
    # The wheel is tagged cp310 / torch222cu121, so it only installs on CPython 3.10.
    stable_fast_wheel = (
        "https://github.com/chengzeyi/stable-fast/releases/download/v1.0.5/"
        "stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl"
    )
    run_command(f"{pip_install} {stable_fast_wheel}")
| |
|
| |
|
def main():
    """Install dependencies, load the pipeline, and serve inference requests.

    Listens on ``SOCKET`` and accepts a single client connection. For each
    message received, it does the following:

    - decodes the message as UTF-8 JSON into a ``TextToImageRequest``;
    - runs ``infer(request, pipeline)``;
    - sends the resulting image back as JPEG-encoded bytes.

    Returns once the client closes its end of the connection, i.e. when
    ``recv_bytes`` raises ``EOFError``.
    """
    print("Installing dependencies...")
    # NOTE(review): load_pipeline / infer are imported from ``pipeline`` at
    # module import time, i.e. before install_dependencies() runs here. If that
    # module imports the packages installed below at its top level, it will not
    # pick up the freshly installed versions -- confirm against pipeline.py.
    install_dependencies()

    print("Loading pipeline")
    pipeline = load_pipeline()
    print("Pipeline loaded")

    # A socket file left behind by a previous run that was killed before the
    # Listener could clean up makes bind() fail with "Address already in use".
    # Only remove the path if it really is a socket, never a regular file.
    socket_path = Path(SOCKET)
    if socket_path.is_socket():
        socket_path.unlink()

    print(f"Creating socket at '{SOCKET}'")
    with Listener(SOCKET) as listener:
        # rwx for everyone, so local processes running as any user can connect.
        # NOTE(review): this also lets any local user submit requests -- confirm intended.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")

            while True:
                try:
                    message = connection.recv_bytes()
                except EOFError:
                    # The client closed its end; nothing more will arrive.
                    print("Inference socket exiting")
                    return

                request = TextToImageRequest.model_validate_json(message.decode("utf-8"))
                image = infer(request, pipeline)

                buffer = BytesIO()
                image.save(buffer, format=JpegImageFile.format)
                connection.send_bytes(buffer.getvalue())
| |
|
| |
|
if __name__ == "__main__":
    # Start the inference server only when run as a script, not on import.
    main()
| |
|