# Hugging Face Space file view (Space status: Sleeping); file size: 4,120 bytes; revision 87a8857.
import io
import zipfile
from typing import Dict
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image
from rembg import remove
from diffusers import DiffusionPipeline
# Hub identifiers handed to DiffusionPipeline.from_pretrained() in load_pipeline().
MODEL_ID = "sudo-ai/zero123plus-v1.2"
CUSTOM_PIPELINE = "sudo-ai/zero123plus-pipeline"

# Process-wide pipeline cache; stays None until load_pipeline() fills it.
pipeline = None

app = FastAPI(title="Zero123++ Inference API")

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination - confirm that credentialed cross-origin
# requests are really needed before keeping it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def load_pipeline():
    """Return the shared Zero123++ pipeline, creating it on first use.

    The pipeline is stored in the module-level ``pipeline`` variable so that
    later calls return the same object without loading it again.

    Returns:
        The object produced by ``DiffusionPipeline.from_pretrained`` after it
        has been moved to CUDA and had attention slicing enabled.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` reports no GPU.
    """
    global pipeline

    if pipeline is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available. Please enable GPU hardware on the Hugging Face Space."
            )

        loaded = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            custom_pipeline=CUSTOM_PIPELINE,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        loaded.to("cuda")
        loaded.enable_attention_slicing()

        # Publish to the module-level cache only once fully configured.
        pipeline = loaded

    return pipeline
def image_to_bytes(image: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode *image* in format *fmt* and return the encoded bytes.

    Args:
        image: Image to serialise; its ``save`` method is called with an
            in-memory buffer.
        fmt: Format name forwarded to ``image.save`` (``"PNG"`` by default).

    Returns:
        Everything ``image.save`` wrote to the buffer.
    """
    with io.BytesIO() as sink:
        image.save(sink, format=fmt)
        # getvalue() returns the whole buffer regardless of stream position.
        return sink.getvalue()
def crop_selected_views(grid: Image.Image) -> Dict[str, Image.Image]:
    """Cut the views kept for LGM out of a Zero123++ grid image.

    The grid is treated as 2 columns x 3 rows, with tiles numbered 1-6 in
    row-major order. Tiles 1, 3, 5 and 6 are cropped, resized to 256x256,
    passed through rembg background removal and flattened onto white.

    Args:
        grid: The multi-view grid image; it is converted to RGB before
            cropping.

    Returns:
        A dict mapping ``"view_<n>_<name>.png"`` to the processed RGB tile,
        in the order 1, 3, 5, 6.
    """
    # Function-scope import so the file's import block does not need to change.
    from rembg.session_factory import new_session

    rgb_grid = grid.convert("RGB")
    width, height = rgb_grid.size
    cols, rows = 2, 3
    tile_w, tile_h = width // cols, height // rows

    selected = {
        1: "front_right",
        3: "right",
        5: "back_right",
        6: "back_left",
    }

    # Build the background-removal session once and share it across tiles.
    # Without an explicit session, remove() sets up a new "u2net" session on
    # every call, which would repeat the model load for each of the tiles.
    session = new_session("u2net")

    outputs: Dict[str, Image.Image] = {}
    for view_no, name in selected.items():
        # Convert the 1-based, row-major tile number to grid coordinates.
        row, col = divmod(view_no - 1, cols)
        left, top = col * tile_w, row * tile_h
        tile = rgb_grid.crop((left, top, left + tile_w, top + tile_h))
        tile = tile.resize((256, 256), Image.LANCZOS)

        # Remove the background, then composite onto an opaque white canvas
        # using the cut-out's alpha channel as the paste mask.
        cutout = remove(tile.convert("RGBA"), session=session)
        canvas = Image.new("RGBA", cutout.size, (255, 255, 255, 255))
        canvas.paste(cutout, mask=cutout.split()[3])

        outputs[f"view_{view_no}_{name}.png"] = canvas.convert("RGB")

    return outputs
@app.get("/")
def root():
    """Return a static description of the service and the model it serves."""
    return dict(
        status="running",
        service="zero123plus-inference",
        model=MODEL_ID,
        output="6-view grid + cropped views 1, 3, 5, 6 for LGM",
    )
@app.get("/health")
def health():
    """Report liveness together with CUDA availability and device name.

    ``cuda_device`` is the name of device 0 when CUDA is available and
    ``None`` otherwise.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        device_name = torch.cuda.get_device_name(0)
    else:
        device_name = None

    return {
        "status": "ok",
        "cuda_available": has_cuda,
        "cuda_device": device_name,
    }
@app.post("/generate")
async def generate(file: UploadFile = File(...), steps: int = 75):
    """Run Zero123++ on an uploaded image and return the results as a ZIP.

    Args:
        file: Uploaded input image; it is decoded and converted to RGB.
        steps: Passed to the pipeline as ``num_inference_steps``; must be at
            least 1.

    Returns:
        On success, a ZIP attachment containing ``multiview_grid.png`` plus
        the tiles produced by ``crop_selected_views``. A JSON response with
        status 400 is returned when ``steps`` is below 1 or the upload cannot
        be decoded as an image, and a JSON response with status 500 for any
        other failure (the traceback is logged).
    """
    # Function-scope import so the file's import block does not need to change.
    import logging

    if steps < 1:
        return JSONResponse(
            status_code=400,
            content={
                "error": f"steps must be >= 1, got {steps}",
                "message": "Invalid number of inference steps.",
            },
        )

    try:
        contents = await file.read()

        # Decode before touching the model so a bad upload is reported as a
        # client error rather than as a generation failure.
        try:
            input_image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            return JSONResponse(
                status_code=400,
                content={
                    "error": str(e),
                    "message": "Uploaded file could not be read as an image.",
                },
            )

        pipe = load_pipeline()

        # NOTE(review): inference is called directly inside this async
        # handler, so the event loop is busy for the whole generation and
        # other requests wait - confirm this serialisation is intended.
        with torch.inference_mode():
            result = pipe(input_image, num_inference_steps=steps).images[0]

        cropped_views = crop_selected_views(result)

        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("multiview_grid.png", image_to_bytes(result))
            for filename, image in cropped_views.items():
                zf.writestr(filename, image_to_bytes(image))
        zip_buffer.seek(0)

        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": "attachment; filename=zero123plus_outputs.zip"
            },
        )
    except Exception as e:
        # The response tells the caller to check the logs, so make sure the
        # traceback actually ends up there.
        logging.getLogger(__name__).exception("Zero123++ generation failed")
        return JSONResponse(
            status_code=500,
            content={
                "error": str(e),
                "message": "Zero123++ generation failed. Check Space logs for details.",
            },
        )