from typing import Any

import base64
import io
import os

import dotenv
import torch
from PIL import Image
from diffusers import DiffusionPipeline  # pyright: ignore[reportPrivateImportUsage]

dotenv.load_dotenv()


def convert_b64_to_image(from_str: str) -> Image.Image:
    """Decode a base64-encoded PNG payload into a fully-loaded PIL image.

    Args:
        from_str: Base64 text of a PNG file.

    Returns:
        The decoded image with pixel data materialized in memory.

    Raises:
        Whatever ``base64.b64decode`` / ``PIL.Image.open`` raise; the error
        is printed to stdout before being re-raised.
    """
    print(">>> call convert_b64_to_image", flush=True)
    try:
        data: bytes = base64.b64decode(from_str)
        with io.BytesIO(data) as bio:
            imgfile = Image.open(bio, formats=["PNG"])
            # Image.open is lazy; force the pixel read while the BytesIO
            # buffer is still open, otherwise later access would fail.
            imgfile.load()
            return imgfile
    except Exception as e:
        print(e, flush=True)
        raise  # bare raise preserves the original traceback


def convert_image_to_b64(from_img: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string (inverse of
    :func:`convert_b64_to_image`).

    Raises:
        Whatever ``Image.save`` raises; the error is printed to stdout
        before being re-raised.
    """
    print(">>> call convert_image_to_b64", flush=True)
    try:
        with io.BytesIO() as buffer:
            from_img.save(buffer, format="PNG")
            byte_data: bytes = buffer.getvalue()
            return base64.b64encode(byte_data).decode("utf-8")
    except Exception as e:
        print(e, flush=True)
        raise


class HFMultiViewGen:
    """Wraps a zero123plus diffusion pipeline that renders multiple views
    (right/back/left) of an object from a single input image."""

    def __init__(self,
                 hf_token: str,
                 mv_model: str = "maple-shaft/zero123plus-v1.2",
                 mv_custom_pipeline: str = "sudo-ai/zero123plus-pipeline",
                 gen_custom_pipeline: str = "",
                 repo_dir: str = "/repository",
                 debug: bool = False):
        """Load the multi-view pipeline onto the GPU.

        Args:
            hf_token: HuggingFace access token used for model download.
            mv_model: Model repo id of the multi-view diffusion weights.
            mv_custom_pipeline: Custom-pipeline repo id passed to diffusers.
            gen_custom_pipeline: Currently unused; kept for interface
                compatibility with existing callers.
            repo_dir: Local cache directory for downloaded weights.
            debug: Stored flag; not read by this class's visible code.
        """
        self.debug = debug
        self.hf_token = hf_token
        self.mv_model = mv_model
        self.mv_custom_pipeline = mv_custom_pipeline
        self.repo_dir = repo_dir
        print(f"torch.cuda.is_available() = {torch.cuda.is_available()}")
        # NOTE(review): this raises on CPU-only hosts even after the
        # availability print above — acceptable here since .to("cuda")
        # below would fail anyway, but confirm the endpoint always has a GPU.
        torch.cuda.synchronize()
        print("GPU SYNC OK", flush=True)
        # NOTE(review): most diffusers releases spell this kwarg
        # `torch_dtype`, not `dtype` — verify against the pinned
        # diffusers version before changing.
        self.pipe = DiffusionPipeline.from_pretrained(
            self.mv_model,
            cache_dir=self.repo_dir,
            token=self.hf_token,
            custom_pipeline=self.mv_custom_pipeline,
            dtype=torch.float16,
        ).to("cuda")

    def generate_multiview(self, initial: Image.Image) -> dict[str, Image.Image]:
        """Run the pipeline on ``initial`` and split the tiled result.

        The pipeline emits one 640x960 PNG tiled as a 2-column x 3-row
        grid of 320x320 views; three of the six tiles are extracted.

        Returns:
            Mapping with keys ``front`` (the RGB-converted input),
            ``right``, ``back`` and ``left`` (tiles cropped from the
            generated grid).
        """
        print(">>> generate_multiview", flush=True)
        print("allocated second pipe to gpu", flush=True)
        # --- prepare image properly ---
        img = initial.convert("RGB")
        print("converted the image to RGB", flush=True)
        mv_result: list[Image.Image] = self.pipe(
            image=img,
            width=640,
            height=960,
            num_inference_steps=28,
            guidance_scale=4.0,
            num_images_per_prompt=1,
        ).images  # pyright: ignore[reportCallIssue]
        print("mv_result", repr(mv_result), flush=True)
        # The result is a single 640x960 image tiled 2 wide x 3 tall,
        # i.e. 320x320 tiles. Crop boxes are (left, upper, right, lower).
        tile = 320
        right_tile = (tile, 0, tile * 2, tile)
        back_tile = (tile, tile, tile * 2, tile * 2)
        left_tile = (0, tile * 2, tile, tile * 3)
        ret = {
            "front": img,
            "right": mv_result[0].crop(right_tile),
            "back": mv_result[0].crop(back_tile),
            "left": mv_result[0].crop(left_tile),
        }
        return ret


class EndpointHandler():
    """HuggingFace inference-endpoint entry point: accepts a base64 PNG
    under ``data['inputs']`` and returns base64 multi-view renders."""

    def __init__(self, path=""):
        """Read credentials/cache paths from the environment and build the
        generator.

        Args:
            path: Optional override for the model cache directory; when
                empty, ``HF_HUB_CACHE`` is used.

        Raises:
            KeyError: if ``HUGGINGFACE_TOKEN`` or ``HF_HUB_CACHE`` is unset.
        """
        self.hf_token = os.environ["HUGGINGFACE_TOKEN"]
        self.repo_dir = os.environ["HF_HUB_CACHE"] if not path else path
        self.hf_gen = HFMultiViewGen(hf_token=self.hf_token, repo_dir=self.repo_dir)

    def convert(self, fromval: dict[str, Image.Image]) -> dict[str, str]:
        """Base64-encode every image in ``fromval``, preserving keys."""
        return {k: convert_image_to_b64(v) for k, v in fromval.items()}

    def __call__(self, data: dict[str, Any]):
        """Handle one inference request.

        Args:
            data: Request payload; ``data['inputs']`` must be a base64 PNG.

        Returns:
            ``{"output": {view_name: base64_png, ...}}``.

        Raises:
            KeyError: when ``inputs`` is missing; any pipeline/decoding
            error is printed and re-raised.
        """
        print("Entered __call__!!! ", repr(data), flush=True)
        ret: dict[str, Any] = {}
        try:
            img_str = data['inputs']
            print(f"Initial image: {img_str}", flush=True)
            img: Image.Image = convert_b64_to_image(img_str)
            print("Converted to image", repr(img), flush=True)
            mv: dict[str, Image.Image] = self.hf_gen.generate_multiview(initial=img)
            print(f"Mv Image: {mv}", flush=True)
            mv_str: dict[str, str] = self.convert(mv)
            ret["output"] = mv_str
            return ret
        except Exception as e:
            print(e)
            raise