# sn39-19/src/main.py — inference socket server entry point (commit b6a5a59).
import os
import subprocess
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod
from os.path import abspath
from pathlib import Path
from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest
from pipeline import load_pipeline, infer
# Absolute path of the Unix-domain socket the client connects to; it lives one
# directory above this file.
SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
def run_command(command):
    """Run *command* through the shell and fail loudly on a non-zero exit.

    Args:
        command: Full shell command line to execute. It is passed with
            ``shell=True``, so it must come from trusted code paths only —
            never from user-supplied input.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
            (``RuntimeError`` subclasses ``Exception``, so existing callers
            catching ``Exception`` keep working.)
    """
    # NOTE(review): shell=True is needed because callers pass whole command
    # lines with arguments; all call sites use hard-coded strings.
    process = subprocess.run(command, shell=True)
    if process.returncode != 0:
        raise RuntimeError(f"Command failed: {command}")
def install_dependencies():
    """Install the pinned packages the inference pipeline depends on."""
    commands = (
        # Torch / torchvision from the CUDA 12.1 wheel index.
        "pip install --index-url https://download.pytorch.org/whl/cu121 torch==2.2.2 torchvision==0.17.2",
        # Remaining pinned Python packages.
        "pip install accelerate==0.31.0 numpy==1.26.4 xformers==0.0.25.post1 triton==2.2.0 transformers==4.41.2",
        # Fetch the prebuilt stable-fast wheel, then install it from disk.
        "wget https://github.com/chengzeyi/stable-fast/releases/download/v1.0.5/stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl",
        "pip install stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl",
    )
    for command in commands:
        run_command(command)
def main():
    """Install dependencies, load the model pipeline, then serve inference
    requests over the Unix socket until the peer disconnects."""
    print("Installing dependencies...")
    install_dependencies()

    print("Loading pipeline")
    pipeline = load_pipeline()
    print("Pipeline loaded")

    print(f"Creating socket at '{SOCKET}'")
    with Listener(SOCKET) as listener:
        # Open the socket to all users so a client running under a different
        # account can still connect.
        chmod(SOCKET, 0o777)
        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")
            while True:
                # recv_bytes is the only call here that raises EOFError;
                # it signals the client has closed the connection.
                try:
                    raw = connection.recv_bytes()
                except EOFError:
                    print("Inference socket exiting")
                    return
                request = TextToImageRequest.model_validate_json(raw.decode("utf-8"))
                image = infer(request, pipeline)
                buffer = BytesIO()
                # JpegImageFile.format is the string "JPEG".
                image.save(buffer, format=JpegImageFile.format)
                connection.send_bytes(buffer.getvalue())
# Script entry point: only start the server when run directly.
if __name__ == '__main__':
    main()