"""FastAPI service that runs Zero123++ on an uploaded image.

The ``/generate`` endpoint returns a ZIP archive containing the full
multi-view grid produced by the pipeline plus four cropped,
background-removed views (1, 3, 5 and 6) intended for LGM.
"""

import io
import logging
import threading
import zipfile
from typing import Dict

import torch
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image
from rembg import remove
from diffusers import DiffusionPipeline

logger = logging.getLogger(__name__)

app = FastAPI(title="Zero123++ Inference API")

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests. The FastAPI CORS docs say origins must be
# listed explicitly when credentials are enabled. Left unchanged so existing
# clients keep working -- confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_ID = "sudo-ai/zero123plus-v1.2"
CUSTOM_PIPELINE = "sudo-ai/zero123plus-pipeline"

# Upper bound for the user-supplied step count so a single request cannot
# occupy the GPU for an unbounded amount of time.
MAX_INFERENCE_STEPS = 200

# Layout of the grid image returned by the pipeline and the tiles kept from it.
GRID_COLS = 2
GRID_ROWS = 3
VIEW_SIZE = (256, 256)
SELECTED_VIEWS = {
    1: "front_right",
    3: "right",
    5: "back_right",
    6: "back_left",
}

pipeline = None

# Guards the lazy initialisation of ``pipeline``.
_pipeline_lock = threading.Lock()
# Serialises GPU work: requests are handled in worker threads, and the
# pipeline object is shared between them.
_inference_lock = threading.Lock()


def load_pipeline():
    """Return the shared Zero123++ pipeline, loading it on first use.

    The pipeline is loaded in float16, moved to CUDA and configured with
    attention slicing. Subsequent calls return the cached instance.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    global pipeline

    if pipeline is not None:
        return pipeline

    with _pipeline_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if pipeline is not None:
            return pipeline

        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available. Please enable GPU hardware on the Hugging Face Space."
            )

        pipe = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            custom_pipeline=CUSTOM_PIPELINE,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        pipe.to("cuda")
        pipe.enable_attention_slicing()
        pipeline = pipe

    return pipeline


def image_to_bytes(image: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode *image* in the given format and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format=fmt)
    return buffer.getvalue()


def crop_selected_views(grid: Image.Image) -> Dict[str, Image.Image]:
    """Cut the selected views out of a Zero123++ grid image.

    The grid is treated as 2 columns x 3 rows, numbered 1-6 in row-major
    order. Views 1, 3, 5 and 6 are kept for LGM. Each kept tile is resized
    to 256x256, has its background removed with rembg and is flattened
    onto a white background.

    Args:
        grid: The grid image produced by the pipeline.

    Returns:
        A mapping from output file name (``view_<n>_<name>.png``) to the
        processed RGB tile.

    Raises:
        ValueError: If the grid is too small to be split into tiles.
    """
    grid = grid.convert("RGB")
    width, height = grid.size
    tile_w, tile_h = width // GRID_COLS, height // GRID_ROWS
    if tile_w == 0 or tile_h == 0:
        raise ValueError(
            f"Grid image of size {width}x{height} is too small to split into "
            f"{GRID_COLS}x{GRID_ROWS} tiles."
        )

    outputs = {}
    for idx_1based, name in SELECTED_VIEWS.items():
        row, col = divmod(idx_1based - 1, GRID_COLS)
        box = (
            col * tile_w,
            row * tile_h,
            (col + 1) * tile_w,
            (row + 1) * tile_h,
        )
        tile = grid.crop(box).resize(VIEW_SIZE, Image.LANCZOS)

        # Remove the background, then paste onto white using the alpha
        # channel produced by rembg as the mask.
        tile_rgba = remove(tile.convert("RGBA"))
        white_bg = Image.new("RGBA", tile_rgba.size, (255, 255, 255, 255))
        white_bg.paste(tile_rgba, mask=tile_rgba.split()[3])

        outputs[f"view_{idx_1based}_{name}.png"] = white_bg.convert("RGB")

    return outputs


@app.get("/")
def root():
    """Return basic service metadata."""
    return {
        "status": "running",
        "service": "zero123plus-inference",
        "model": MODEL_ID,
        "output": "6-view grid + cropped views 1, 3, 5, 6 for LGM",
    }


@app.get("/health")
def health():
    """Report liveness and whether a CUDA device is visible."""
    cuda_available = torch.cuda.is_available()
    return {
        "status": "ok",
        "cuda_available": cuda_available,
        "cuda_device": torch.cuda.get_device_name(0) if cuda_available else None,
    }


def _build_zip(input_image: Image.Image, steps: int) -> bytes:
    """Run the pipeline on *input_image* and return the result ZIP as bytes.

    Blocking; intended to be called from a worker thread.
    """
    # One generation at a time: the pipeline instance is shared between
    # worker threads.
    with _inference_lock:
        pipe = load_pipeline()
        with torch.inference_mode():
            result = pipe(input_image, num_inference_steps=steps).images[0]
        cropped_views = crop_selected_views(result)

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("multiview_grid.png", image_to_bytes(result))
        for filename, image in cropped_views.items():
            zf.writestr(filename, image_to_bytes(image))
    return zip_buffer.getvalue()


@app.post("/generate")
async def generate(
    file: UploadFile = File(...),
    steps: int = Query(75, ge=1, le=MAX_INFERENCE_STEPS),
):
    """Generate multi-view images from an uploaded picture.

    Args:
        file: The uploaded input image.
        steps: Number of inference steps, between 1 and
            ``MAX_INFERENCE_STEPS``. Out-of-range values are rejected by
            FastAPI's request validation before this handler runs.

    Returns:
        A ZIP attachment with ``multiview_grid.png`` and the cropped views
        on success; a JSON body with ``error`` and ``message`` and status
        400 (upload is not a decodable image) or 500 (generation failed).
    """
    try:
        contents = await file.read()

        try:
            input_image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # A bad upload is a client error, not a server failure.
            return JSONResponse(
                status_code=400,
                content={
                    "error": str(exc),
                    "message": "Uploaded file is not a valid image.",
                },
            )

        # Inference is blocking and long-running; run it off the event loop
        # so other requests (e.g. /health) stay responsive.
        zip_bytes = await run_in_threadpool(_build_zip, input_image, steps)
    except Exception as e:
        # The response tells the user to check the logs, so make sure the
        # traceback is actually there.
        logger.exception("Zero123++ generation failed")
        return JSONResponse(
            status_code=500,
            content={
                "error": str(e),
                "message": "Zero123++ generation failed. Check Space logs for details.",
            },
        )

    return StreamingResponse(
        io.BytesIO(zip_bytes),
        media_type="application/zip",
        headers={
            "Content-Disposition": "attachment; filename=zero123plus_outputs.zip"
        },
    )