| | import atexit |
| | from io import BytesIO |
| | from multiprocessing.connection import Listener |
| | from os import chmod, remove |
| | from os.path import abspath, exists |
| | from pathlib import Path |
| | from git import Repo |
| | import torch |
| |
|
| | from PIL.JpegImagePlugin import JpegImageFile |
| | from pipelines.models import TextToImageRequest |
| | from pipeline import load_pipeline, infer |
# Absolute path of the Unix-domain socket used for inference requests: a file
# named "inferences.sock" in the grandparent directory of this module.
SOCKET = abspath(Path(__file__).parent.parent.joinpath("inferences.sock"))
| |
|
| |
|
def at_exit():
    """Ask PyTorch to release the CUDA memory its caching allocator holds.

    Intended to be registered with ``atexit`` so it runs at interpreter exit.
    """
    cuda_api = torch.cuda
    cuda_api.empty_cache()
| |
|
| |
|
def main():
    """Serve text-to-image inference requests over a Unix-domain socket.

    Loads the pipeline, (re)creates the socket at ``SOCKET``, accepts a single
    client connection and answers requests on it until the client disconnects.
    Each request is a UTF-8 encoded JSON ``TextToImageRequest``; each reply is
    the generated image encoded as JPEG bytes.

    Raises:
        RuntimeError: if the pipeline could not be loaded. Raised before the
            socket is created, so no client is ever accepted by a server
            that cannot answer it.
    """
    atexit.register(at_exit)

    print("Loading pipeline")
    pipeline = _load_pipeline()
    if pipeline is None:
        # _load_pipeline() reports failure by returning None. Serving with a
        # None pipeline would only crash later, inside infer(), on the first
        # request and with an unrelated error, so stop here instead.
        raise RuntimeError("Pipeline failed to load; not creating the inference socket")

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # Remove a socket file left over from a previous run; binding a Unix
    # socket to an existing path fails.
    if exists(SOCKET):
        remove(SOCKET)

    with Listener(SOCKET) as listener:
        # NOTE(review): world read/write, presumably so a client running as a
        # different user can connect -- confirm this is intended.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")
            generator = torch.Generator("cuda")
            while True:
                # Only recv_bytes() signals the client hanging up (EOFError);
                # keep the try around just that call.
                try:
                    payload = connection.recv_bytes()
                except EOFError:
                    print("Inference socket exiting")
                    return

                request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))
                # Reseed per request so identical requests give identical images.
                image = infer(request, pipeline, generator.manual_seed(request.seed))

                data = BytesIO()
                image.save(data, format=JpegImageFile.format)
                connection.send_bytes(data.getvalue())
| |
|
def _load_pipeline():
    """Load the inference pipeline, gated on an author/remote-URL check.

    Reads ``loss_params.pth`` (relative to the current working directory),
    takes ``["metadata"]["author"]`` from it, and loads the pipeline only when
    that author string occurs in the ``origin`` remote URL returned by
    ``get_git_remote_url()``.

    Returns:
        The object returned by ``load_pipeline()``, or ``None`` if the check
        fails or any step raises an ``Exception``. The reason is printed.
    """
    try:
        # NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code; only safe while loss_params.pth is trusted. Consider
        # weights_only=True if the installed torch supports it.
        loaded_data = torch.load("loss_params.pth")
        author = loaded_data["metadata"]["author"]
        remote_url = get_git_remote_url()
        # Check before loading: load_pipeline() is expensive, and its result
        # would be thrown away anyway when the check fails.
        if remote_url is None or author not in remote_url:
            print("Pipeline check failed; pipeline not loaded")
            return None
        return load_pipeline()
    except Exception as e:
        # Keep the original best-effort contract (return None), but say why,
        # and let KeyboardInterrupt/SystemExit propagate.
        print(f"Failed to load pipeline: {e}")
        return None
| |
|
| |
|
def get_git_remote_url():
    """Return the URL of the ``origin`` remote of the Git repo at ``"."``.

    Never raises an ``Exception``. Any failure (for example, the working
    directory is not a repository or has no ``origin`` remote) is printed and
    ``None`` is returned.
    """
    try:
        return Repo(".").remotes.origin.url
    except Exception as e:
        print(f"Error: {e}")
        return None
| |
|
if __name__ == "__main__":
    # Run the inference server only when executed as a script, not on import.
    main()
| |
|