|
|
|
|
|
|
|
|
import os |
|
|
from typing import Any, Dict |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from huggingface_inference_toolkit.logging import logger |
|
|
|
|
|
# Join the default process group before anything below asks for rank
# information.
dist.init_process_group()

# Bind this process to a single GPU. The device index must be the per-node
# (local) rank: the global rank only coincides with it on a single node and
# exceeds the device count on multi-node launches. Launchers such as torchrun
# export LOCAL_RANK; fall back to the global rank when it is not set, which
# keeps the previous single-node behaviour.
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", dist.get_rank())))
|
|
|
|
|
from para_attn.context_parallel import init_context_parallel_mesh |
|
|
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe |
|
|
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae |
|
|
|
|
|
class EndpointHandler:
    """Callable request handler that validates the text prompt of a request.

    NOTE(review): no generation pipeline is wired up in this class. ``__call__``
    only validates the request body and returns the constant ``"1"``; this
    looks like a stub, confirm before relying on it.
    """

    # Request-body keys that may carry the prompt, in lookup order.
    _PROMPT_KEYS = ("inputs", "prompt")

    def __init__(self, path=""):
        """Create the handler.

        Args:
            path: Accepted for interface compatibility; currently unused.
        """

    @staticmethod
    def _pop_prompt(data: Dict[str, Any]) -> str:
        """Remove and return the prompt from ``data``.

        ``inputs`` is checked before ``prompt``. A key is only accepted when
        its value is a non-empty string, which is what the error message
        promises callers.

        Args:
            data: Request body; the accepted key is removed from it in place.

        Returns:
            The prompt string.

        Raises:
            ValueError: If neither key holds a non-empty string.
        """
        for key in EndpointHandler._PROMPT_KEYS:
            candidate = data.get(key)
            if isinstance(candidate, str) and candidate:
                return data.pop(key)
        raise ValueError(
            "Provided input body must contain either the key `inputs` or `prompt` with the"
            " prompt to use for the image generation, and it needs to be a non-empty string."
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        """Validate an incoming request body.

        Args:
            data: Request body; must contain a non-empty string under
                ``inputs`` or ``prompt``. The key that is used is removed
                from ``data``.

        Returns:
            The constant string ``"1"``.

        Raises:
            ValueError: If no usable prompt is present in ``data``.
        """
        logger.info(f"Received incoming request with {data=}")

        # The prompt is validated and stripped from the body, but nothing
        # consumes its value yet, so it is not bound to a local.
        self._pop_prompt(data)

        return "1"
|
|
|