harshabasavarajbeth commited on
Commit
87a8857
·
verified ·
1 Parent(s): 802a50d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import zipfile
3
+ from typing import Dict
4
+
5
+ import torch
6
+ from fastapi import FastAPI, File, UploadFile
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import JSONResponse, StreamingResponse
9
+ from PIL import Image
10
+ from rembg import remove
11
+ from diffusers import DiffusionPipeline
12
+
13
+
14
+ app = FastAPI(title="Zero123++ Inference API")
15
+
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"],
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ MODEL_ID = "sudo-ai/zero123plus-v1.2"
25
+ CUSTOM_PIPELINE = "sudo-ai/zero123plus-pipeline"
26
+
27
+ pipeline = None
28
+
29
+
30
+ def load_pipeline():
31
+ global pipeline
32
+
33
+ if pipeline is not None:
34
+ return pipeline
35
+
36
+ if not torch.cuda.is_available():
37
+ raise RuntimeError(
38
+ "CUDA GPU is not available. Please enable GPU hardware on the Hugging Face Space."
39
+ )
40
+
41
+ pipe = DiffusionPipeline.from_pretrained(
42
+ MODEL_ID,
43
+ custom_pipeline=CUSTOM_PIPELINE,
44
+ torch_dtype=torch.float16,
45
+ trust_remote_code=True,
46
+ )
47
+
48
+ pipe.to("cuda")
49
+ pipe.enable_attention_slicing()
50
+
51
+ pipeline = pipe
52
+ return pipeline
53
+
54
+
55
+ def image_to_bytes(image: Image.Image, fmt: str = "PNG") -> bytes:
56
+ buffer = io.BytesIO()
57
+ image.save(buffer, format=fmt)
58
+ buffer.seek(0)
59
+ return buffer.getvalue()
60
+
61
+
62
+ def crop_selected_views(grid: Image.Image) -> Dict[str, Image.Image]:
63
+ """
64
+ Zero123++ output is expected as a 2-column x 3-row grid.
65
+ We keep views 1, 3, 5, and 6 for LGM.
66
+ """
67
+
68
+ grid = grid.convert("RGB")
69
+ w, h = grid.size
70
+
71
+ cols, rows = 2, 3
72
+ tile_w, tile_h = w // cols, h // rows
73
+
74
+ selected = {
75
+ 1: "front_right",
76
+ 3: "right",
77
+ 5: "back_right",
78
+ 6: "back_left",
79
+ }
80
+
81
+ outputs = {}
82
+
83
+ for idx_1based, name in selected.items():
84
+ row = (idx_1based - 1) // cols
85
+ col = (idx_1based - 1) % cols
86
+
87
+ box = (
88
+ col * tile_w,
89
+ row * tile_h,
90
+ (col + 1) * tile_w,
91
+ (row + 1) * tile_h,
92
+ )
93
+
94
+ tile = grid.crop(box).resize((256, 256), Image.LANCZOS)
95
+
96
+ # Remove background and paste on white
97
+ tile_rgba = remove(tile.convert("RGBA"))
98
+ white_bg = Image.new("RGBA", tile_rgba.size, (255, 255, 255, 255))
99
+ white_bg.paste(tile_rgba, mask=tile_rgba.split()[3])
100
+
101
+ outputs[f"view_{idx_1based}_{name}.png"] = white_bg.convert("RGB")
102
+
103
+ return outputs
104
+
105
+
106
+ @app.get("/")
107
+ def root():
108
+ return {
109
+ "status": "running",
110
+ "service": "zero123plus-inference",
111
+ "model": MODEL_ID,
112
+ "output": "6-view grid + cropped views 1, 3, 5, 6 for LGM",
113
+ }
114
+
115
+
116
+ @app.get("/health")
117
+ def health():
118
+ return {
119
+ "status": "ok",
120
+ "cuda_available": torch.cuda.is_available(),
121
+ "cuda_device": torch.cuda.get_device_name(0)
122
+ if torch.cuda.is_available()
123
+ else None,
124
+ }
125
+
126
+
127
+ @app.post("/generate")
128
+ async def generate(file: UploadFile = File(...), steps: int = 75):
129
+ try:
130
+ pipe = load_pipeline()
131
+
132
+ contents = await file.read()
133
+ input_image = Image.open(io.BytesIO(contents)).convert("RGB")
134
+
135
+ with torch.inference_mode():
136
+ result = pipe(input_image, num_inference_steps=steps).images[0]
137
+
138
+ cropped_views = crop_selected_views(result)
139
+
140
+ zip_buffer = io.BytesIO()
141
+
142
+ with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
143
+ zf.writestr("multiview_grid.png", image_to_bytes(result))
144
+
145
+ for filename, image in cropped_views.items():
146
+ zf.writestr(filename, image_to_bytes(image))
147
+
148
+ zip_buffer.seek(0)
149
+
150
+ return StreamingResponse(
151
+ zip_buffer,
152
+ media_type="application/zip",
153
+ headers={
154
+ "Content-Disposition": "attachment; filename=zero123plus_outputs.zip"
155
+ },
156
+ )
157
+
158
+ except Exception as e:
159
+ return JSONResponse(
160
+ status_code=500,
161
+ content={
162
+ "error": str(e),
163
+ "message": "Zero123++ generation failed. Check Space logs for details.",
164
+ },
165
+ )