harshabasavarajbeth's picture
Create app.py
87a8857 verified
import io
import logging
import zipfile
from typing import Dict

import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image
from rembg import new_session, remove
# FastAPI application that serves Zero123++ multi-view generation over HTTP.
app = FastAPI(title="Zero123++ Inference API")

# Open CORS policy: any origin, method and header may call the API.
# Credentials are disabled because a wildcard origin must not be combined
# with credentialed (cookie / Authorization) requests; explicit origins would
# have to be listed for that. None of the routes registered on this app in
# this file read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hub identifiers handed to DiffusionPipeline.from_pretrained() by
# load_pipeline(): the model weights and the custom pipeline implementation.
MODEL_ID = "sudo-ai/zero123plus-v1.2"
CUSTOM_PIPELINE = "sudo-ai/zero123plus-pipeline"
# Cached pipeline instance; stays None until load_pipeline() has built it.
pipeline = None
def load_pipeline():
    """Return the shared Zero123++ pipeline, building it on first use.

    On the first call the pipeline is created from ``MODEL_ID`` with the
    ``CUSTOM_PIPELINE`` implementation in float16, moved to the CUDA device
    and stored in the module-level ``pipeline`` variable. Every later call
    returns that stored object without reloading anything.

    Returns:
        The loaded pipeline object.

    Raises:
        RuntimeError: If torch reports that no CUDA device is available.
    """
    global pipeline

    if pipeline is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available. Please enable GPU hardware on the Hugging Face Space."
            )

        loaded = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            custom_pipeline=CUSTOM_PIPELINE,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        loaded.to("cuda")
        # Presumably a memory-saving option of the pipeline; see diffusers docs.
        loaded.enable_attention_slicing()

        # Publish only after the pipeline is fully configured.
        pipeline = loaded

    return pipeline
def image_to_bytes(image: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode *image* in memory and return the encoded file contents.

    Args:
        image: Image to serialize.
        fmt: Format name passed to ``Image.save`` (``"PNG"`` by default).

    Returns:
        The complete encoded image as a bytes object.
    """
    with io.BytesIO() as encoded:
        image.save(encoded, format=fmt)
        # getvalue() returns the whole buffer regardless of stream position.
        return encoded.getvalue()
def crop_selected_views(grid: Image.Image) -> Dict[str, Image.Image]:
    """Cut the views used for LGM out of a Zero123++ multi-view grid.

    The grid is expected to be laid out as 2 columns x 3 rows, with views
    numbered 1-6 in row-major order. Views 1, 3, 5 and 6 are cropped, resized
    to 256x256, stripped of their background and flattened onto white.

    Args:
        grid: The full multi-view grid image produced by the pipeline.

    Returns:
        Mapping from output file name (``view_<n>_<label>.png``) to the
        processed RGB tile, in ascending view order.
    """
    cols, rows = 2, 3
    # NOTE(review): the direction labels only end up in the file names;
    # confirm they match the camera poses of the model version in use.
    selected = {
        1: "front_right",
        3: "right",
        5: "back_right",
        6: "back_left",
    }

    grid = grid.convert("RGB")
    width, height = grid.size
    tile_w, tile_h = width // cols, height // rows

    # Build the background-removal session once and share it across tiles;
    # when no session is passed, remove() sets up a new one on every call.
    session = new_session()

    outputs = {}
    for idx_1based, name in selected.items():
        # Row-major position of the 1-based view index inside the grid.
        row, col = divmod(idx_1based - 1, cols)
        left, top = col * tile_w, row * tile_h
        box = (left, top, left + tile_w, top + tile_h)
        tile = grid.crop(box).resize((256, 256), Image.LANCZOS)

        # Remove the background, then paste onto white using the alpha
        # channel as mask so transparent pixels become white.
        tile_rgba = remove(tile.convert("RGBA"), session=session)
        white_bg = Image.new("RGBA", tile_rgba.size, (255, 255, 255, 255))
        white_bg.paste(tile_rgba, mask=tile_rgba.split()[3])

        outputs[f"view_{idx_1based}_{name}.png"] = white_bg.convert("RGB")
    return outputs
@app.get("/")
def root():
    """Describe the running service: status, service name, model id, outputs."""
    return dict(
        status="running",
        service="zero123plus-inference",
        model=MODEL_ID,
        output="6-view grid + cropped views 1, 3, 5, 6 for LGM",
    )
@app.get("/health")
def health():
    """Report liveness and whether torch can see a CUDA device.

    Returns:
        A dict with ``status``, ``cuda_available`` and ``cuda_device`` (the
        name of device 0, or ``None`` when CUDA is unavailable).
    """
    has_cuda = torch.cuda.is_available()

    device_name = None
    if has_cuda:
        device_name = torch.cuda.get_device_name(0)

    return {
        "status": "ok",
        "cuda_available": has_cuda,
        "cuda_device": device_name,
    }
_logger = logging.getLogger(__name__)


def _error_response(status_code: int, error: str, message: str) -> JSONResponse:
    """Build the JSON error payload shared by every failure path of /generate."""
    return JSONResponse(
        status_code=status_code,
        content={"error": error, "message": message},
    )


@app.post("/generate")
async def generate(file: UploadFile = File(...), steps: int = 75):
    """Generate Zero123++ multi-view images from an uploaded picture.

    Args:
        file: Uploaded input image (multipart form field ``file``).
        steps: Number of inference steps passed to the pipeline; must be at
            least 1.

    Returns:
        On success, a ZIP download holding ``multiview_grid.png`` plus the
        cropped views. On failure, a JSON body with ``error`` and ``message``:
        status 400 for an invalid ``steps`` value or an undecodable upload,
        status 500 when generation itself fails.
    """
    # Reject nonsensical step counts up front instead of letting them fail
    # somewhere inside the pipeline as a 500.
    if steps < 1:
        return _error_response(
            400,
            f"steps must be >= 1, got {steps}",
            "Invalid number of inference steps.",
        )

    try:
        contents = await file.read()

        try:
            input_image = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as e:
            # Pillow signals unreadable or truncated image data with OSError
            # (UnidentifiedImageError is a subclass); that is a client error.
            return _error_response(
                400,
                str(e),
                "The uploaded file could not be read as an image.",
            )

        pipe = load_pipeline()

        # NOTE: inference runs synchronously inside this coroutine, so the
        # event loop is busy until the pipeline call returns.
        with torch.inference_mode():
            result = pipe(input_image, num_inference_steps=steps).images[0]

        cropped_views = crop_selected_views(result)

        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("multiview_grid.png", image_to_bytes(result))
            for filename, image in cropped_views.items():
                zf.writestr(filename, image_to_bytes(image))
        # Rewind so the response streams the archive from its first byte.
        zip_buffer.seek(0)

        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": "attachment; filename=zero123plus_outputs.zip"
            },
        )
    except Exception as e:
        # The response tells the caller to check the logs, so make sure the
        # traceback actually ends up there.
        _logger.exception("Zero123++ generation failed")
        return _error_response(
            500,
            str(e),
            "Zero123++ generation failed. Check Space logs for details.",
        )