File size: 4,433 Bytes
88ef5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
064894c
88ef5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import Dict, List, Any
import os
import torch
from PIL import Image
import dotenv
import base64
import io
from diffusers import DiffusionPipeline  # pyright: ignore[reportPrivateImportUsage]

dotenv.load_dotenv()

def convert_b64_to_image(from_str: str) -> Image.Image:
    """Decode a base64-encoded PNG string into a fully-loaded PIL image.

    Args:
        from_str: Base64 text representing the bytes of a PNG file.

    Returns:
        The decoded image. ``load()`` is called so the pixel data is read
        into memory before the backing buffer is closed.

    Raises:
        Whatever ``b64decode`` / ``Image.open`` raise on bad input; the
        exception is printed and re-raised unchanged.
    """
    print(">>> call convert_b64_to_image", flush=True)
    try:
        data: bytes = base64.b64decode(from_str)
        with io.BytesIO(data) as bio:
            # formats=["PNG"] restricts decoding to PNG, matching what
            # convert_image_to_b64 produces on the way out.
            imgfile = Image.open(bio, formats=["PNG"])
            # Force the pixel data in now: the BytesIO is closed when the
            # `with` exits, so a lazily-opened image would fail later.
            imgfile.load()
            return imgfile
    except Exception as e:
        print(e, flush=True)
        # Bare `raise` keeps the original traceback intact instead of
        # appending this frame (as `raise e` would).
        raise
    
def convert_image_to_b64(from_img: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string.

    Args:
        from_img: The image to serialize.

    Returns:
        Base64 text (ASCII-safe ``str``) of the image saved as PNG.

    Raises:
        Whatever ``Image.save`` raises; the exception is printed and
        re-raised unchanged.
    """
    print(">>> call convert_image_to_b64", flush=True)
    try:
        with io.BytesIO() as buffer:
            from_img.save(buffer, format="PNG")
            byte_data: bytes = buffer.getvalue()
            return base64.b64encode(byte_data).decode("utf-8")
    except Exception as e:
        print(e, flush=True)
        # Bare `raise` keeps the original traceback intact instead of
        # appending this frame (as `raise e` would).
        raise

class HFMultiViewGen:
    """Wraps a diffusers multi-view pipeline (zero123plus) that turns one
    input image into additional viewpoints of the same subject.

    The pipeline is downloaded/loaded eagerly in ``__init__`` and moved to
    CUDA, so constructing this class requires a GPU.
    """

    def __init__(self,
                 hf_token: str,
                 mv_model: str = "maple-shaft/zero123plus-v1.2",
                 mv_custom_pipeline: str = "sudo-ai/zero123plus-pipeline",
                 gen_custom_pipeline: str = "",
                 repo_dir: str = "/repository",
                 debug: bool = False):
        """Load the multi-view diffusion pipeline onto the GPU.

        Args:
            hf_token: Hugging Face token used to download the model.
            mv_model: Model repo id for the multi-view checkpoint.
            mv_custom_pipeline: Custom-pipeline repo id passed to diffusers.
            gen_custom_pipeline: Accepted but never stored or used here —
                presumably reserved for a second generation pipeline.
            repo_dir: Local cache directory for downloaded weights.
            debug: Stored on the instance; not consulted in this class.
        """
        self.debug = debug
        self.hf_token = hf_token
        self.mv_model = mv_model
        self.mv_custom_pipeline = mv_custom_pipeline
        self.repo_dir = repo_dir

        # Sanity-check the GPU before the (slow) model load; synchronize()
        # raises if no CUDA device is usable.
        print(f"torch.cuda.is_available() = {torch.cuda.is_available()}")
        torch.cuda.synchronize()
        print("GPU SYNC OK", flush=True)

        # NOTE(review): diffusers' from_pretrained conventionally takes
        # `torch_dtype=`; confirm `dtype=` is honored by the installed
        # diffusers version and not silently ignored.
        self.pipe = DiffusionPipeline.from_pretrained(
            self.mv_model,
            cache_dir=self.repo_dir,
            token=self.hf_token,
            custom_pipeline=self.mv_custom_pipeline,
            dtype=torch.float16
        ).to("cuda")

    def generate_multiview(self, initial: Image.Image) -> dict[str, Image.Image]:
        """Generate right/back/left views of ``initial``.

        Returns a dict with keys ``front`` (the RGB-converted input image,
        unchanged), ``right``, ``back`` and ``left`` (320x320 crops of the
        pipeline's tiled output).
        """
        print(">>> generate_multiview", flush=True)

        print("allocated second pipe to gpu", flush=True)
        # The pipeline expects RGB; drop any alpha channel / palette mode.
        img = initial.convert("RGB")

        print("converted the image to RGB", flush=True)

        # Single diffusion call; .images is the list of output PIL images
        # (only mv_result[0] is used below).
        mv_result : List[Image.Image] = self.pipe(
            image=img,
            width=640,
            height=960,
            num_inference_steps=28,
            guidance_scale=4.0,
            num_images_per_prompt=1
        ).images  # pyright: ignore[reportCallIssue]

        print("mv_result", repr(mv_result), flush=True)

        # The result comes back as a 640x960 PNG assumed to be a 2-column x
        # 3-row grid of 320x320 tiles; split out the tiles we need.
        # NOTE(review): tile size is hard-coded to 320 rather than derived
        # from the output dimensions (see commented-out expressions) —
        # confirm it stays valid if width/height above ever change.
        tile_w = 320.0 # img.width / 2.0
        tile_h = 320.0 # img.height / 3.0
        # Crop boxes are (left, upper, right, lower) in pixels. Three of the
        # six tiles are unused; presumably they hold other viewpoints.
        right_tile = (tile_w, 0.0, tile_w * 2.0, tile_h)
        back_tile = (tile_w, tile_h, tile_w * 2.0, tile_h * 2.0)
        left_tile = (0, tile_h * 2.0, tile_w, tile_h * 3.0)
        ret = {
            "front": img,
            "right": mv_result[0].crop(right_tile),
            "back": mv_result[0].crop(back_tile),
            "left": mv_result[0].crop(left_tile)
        }

        return ret

class EndpointHandler():
    """Inference-endpoint entry point: decodes a base64 PNG from the request,
    runs the multi-view generator, and returns every view re-encoded as a
    base64 string under the ``output`` key."""

    def __init__(self, path=""):
        # `path` (when provided by the endpoint runtime) locates the model
        # repository; otherwise fall back to the HF hub cache from the env.
        # Both env reads raise KeyError if the variable is missing.
        self.hf_token = os.environ["HUGGINGFACE_TOKEN"]
        self.repo_dir = path if path else os.environ["HF_HUB_CACHE"]
        self.hf_gen = HFMultiViewGen(hf_token=self.hf_token, repo_dir=self.repo_dir)

    def convert(self, fromval: dict[str, Image.Image]) -> dict[str, str]:
        """Encode every image in the mapping as a base64 PNG string,
        preserving the view-name keys."""
        return {view: convert_image_to_b64(image) for view, image in fromval.items()}

    def __call__(self, data: Dict[str, Any]):
        """Handle one request: ``data['inputs']`` is a base64 PNG; returns
        ``{"output": {view_name: base64_png, ...}}``. Exceptions are printed
        and re-raised for the endpoint framework to report."""
        print("Entered __call__!!! ", repr(data), flush=True)
        try:
            img_str = data['inputs']
            print(f"Initial image: {img_str}", flush=True)
            source_img: Image.Image = convert_b64_to_image(img_str)
            print("Converted to image", repr(source_img), flush=True)
            views: dict[str, Image.Image] = self.hf_gen.generate_multiview(initial=source_img)
            print(f"Mv Image: {views}", flush=True)
            return {"output": self.convert(views)}
        except Exception as e:
            print(e)
            raise e