"""Inference worker: installs runtime dependencies, loads the pipeline, and
serves text-to-image requests over a Unix domain socket."""

import os
import shlex
import subprocess
import sys
from contextlib import suppress
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod
from os.path import abspath
from pathlib import Path

from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest

# NOTE(review): this import runs before install_dependencies() in main(). If
# `pipeline` imports any package that the install step provides, that package
# must already be present before this script starts.
from pipeline import load_pipeline, infer

# Socket file lives in the parent of the directory containing this script.
SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")


def run_command(command):
    """Run *command* and raise if it exits with a non-zero status.

    Args:
        command: Either a sequence of arguments, executed directly without a
            shell, or a string, executed through the shell as before. Only
            pass trusted strings.

    Raises:
        RuntimeError: If the command's exit status is non-zero.
    """
    use_shell = isinstance(command, str)
    process = subprocess.run(command, shell=use_shell)
    if process.returncode != 0:
        printable = command if use_shell else shlex.join(map(str, command))
        raise RuntimeError(f"Command failed: {printable}")


def _pip_install(*args):
    """Run ``pip install <args>`` for the interpreter executing this script."""
    # `sys.executable -m pip` guarantees packages land in *this* interpreter's
    # environment; a bare `pip` on PATH may belong to a different Python.
    run_command([sys.executable, "-m", "pip", "install", *args])


def install_dependencies():
    """Install the pinned runtime dependencies into the current interpreter.

    Installs torch/torchvision from the PyTorch cu121 wheel index, the
    remaining pinned packages from the default index, and the prebuilt
    stable-fast wheel (cp310, torch 2.2.2 + cu121).

    Raises:
        RuntimeError: If any pip invocation fails.
    """
    # Torch builds come from the cu121-specific PyTorch index.
    _pip_install(
        "--index-url", "https://download.pytorch.org/whl/cu121",
        "torch==2.2.2", "torchvision==0.17.2",
    )
    _pip_install(
        "accelerate==0.31.0", "numpy==1.26.4", "xformers==0.0.25.post1",
        "triton==2.2.0", "transformers==4.41.2",
    )
    # pip installs a wheel directly from its URL. This avoids depending on
    # wget and leaves no downloaded file behind in the working directory.
    _pip_install(
        "https://github.com/chengzeyi/stable-fast/releases/download/v1.0.5/"
        "stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl"
    )


def main():
    """Install dependencies, load the pipeline, then serve a single client.

    Each message received on the socket is a UTF-8 JSON ``TextToImageRequest``.
    Each reply is the generated image encoded as JPEG bytes. Returns when the
    client closes its end of the connection.
    """
    print("Installing dependencies...")
    install_dependencies()

    print("Loading pipeline")
    pipeline = load_pipeline()
    print("Pipeline loaded")

    # A socket file left behind by a crashed earlier run would make bind()
    # fail with "Address already in use", so clear it first.
    with suppress(FileNotFoundError):
        os.unlink(SOCKET)

    print(f"Creating socket at '{SOCKET}'")
    with Listener(SOCKET) as listener:
        # NOTE(review): 0o777 lets any local user connect. Presumably the
        # client runs as a different user; tighten this if not.
        chmod(SOCKET, 0o777)
        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")
            while True:
                # Only recv_bytes() signals the peer closing (EOFError); keep
                # validation outside the try so its errors are not masked.
                try:
                    payload = connection.recv_bytes()
                except EOFError:
                    print("Inference socket exiting")
                    return

                request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))
                image = infer(request, pipeline)

                buffer = BytesIO()
                image.save(buffer, format=JpegImageFile.format)
                connection.send_bytes(buffer.getvalue())


if __name__ == '__main__':
    main()