| | import os |
| | import torch |
| | from PIL import Image |
| | from diffusers import FluxControlNetModel |
| | from diffusers.pipelines import FluxControlNetPipeline |
| | from io import BytesIO |
| | import logging |
| |
|
class EndpointHandler:
    """Inference handler that upscales images with the FLUX.1-dev ControlNet upscaler.

    Loads the ControlNet weights from `model_dir` and the base FLUX.1-dev
    pipeline, authenticated with the HF_TOKEN environment variable.
    """

    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        # Reduce CUDA allocator fragmentation for the large FLUX weights.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

        HF_TOKEN = os.getenv('HF_TOKEN')
        if not HF_TOKEN:
            raise ValueError("HF_TOKEN environment variable is not set")

        logging.basicConfig(level=logging.INFO)
        logging.info("Using HF_TOKEN")

        torch.cuda.empty_cache()

        # `token` replaces the deprecated `use_auth_token` kwarg.
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=HF_TOKEN
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=HF_TOKEN,
        )
        self.pipe.enable_attention_slicing("auto")
        # BUGFIX: the original called pipe.to("cuda") before enabling
        # sequential CPU offload. Offload manages device placement itself and
        # conflicts with a pipeline already moved to CUDA, so .to("cuda") is
        # removed.
        self.pipe.enable_sequential_cpu_offload()
        # BUGFIX: the original called pipe.enable_memory_efficient_attention(),
        # which does not exist on diffusers pipelines (AttributeError at
        # startup). The correct API is
        # enable_xformers_memory_efficient_attention(); it needs the optional
        # xformers package, so failure is tolerated as best-effort.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            logging.info("xformers memory-efficient attention unavailable: %s", exc)

    def preprocess(self, data):
        """Extract the control image from the request payload.

        Raises:
            ValueError: when 'control_image' is missing or falsy.
        Returns:
            PIL.Image resized to the fixed 512x512 conditioning resolution.
        """
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = Image.open(image_file)
        return image.resize((512, 512))

    def postprocess(self, output):
        """Serialize the output image to an in-memory PNG buffer, rewound to 0."""
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def inference(self, data):
        """Run the ControlNet upscaling pipeline; return a PNG BytesIO buffer."""
        control_image = self.preprocess(data)
        torch.cuda.empty_cache()
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,
            num_inference_steps=10,
            # Output dimensions follow the (resized) control image.
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]
        return self.postprocess(output_image)
| |
|
| | if __name__ == "__main__": |
| | data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'} |
| | handler = EndpointHandler() |
| | output = handler.inference(data) |
| | print(output) |
| |
|