Spaces:
Sleeping
Sleeping
| import io | |
| import zipfile | |
| from typing import Dict | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from PIL import Image | |
| from rembg import remove | |
| from diffusers import DiffusionPipeline | |
# FastAPI application that serves Zero123++ multi-view generation.
app = FastAPI(title="Zero123++ Inference API")

# Open CORS policy: any origin, method and header is accepted.
# Credentials are disabled on purpose: a wildcard origin must not be combined
# with credentialed requests (FastAPI's CORS documentation requires explicit
# origins/methods/headers when allow_credentials=True). None of the handlers
# visible in this file read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Identifiers handed to DiffusionPipeline.from_pretrained by load_pipeline().
MODEL_ID = "sudo-ai/zero123plus-v1.2"
CUSTOM_PIPELINE = "sudo-ai/zero123plus-pipeline"

# Lazily created pipeline singleton; stays None until load_pipeline() succeeds.
pipeline = None
def load_pipeline():
    """Return the shared diffusion pipeline, building it on the first call.

    The pipeline is created from ``MODEL_ID`` with ``CUSTOM_PIPELINE`` in
    float16, moved to the CUDA device, and has attention slicing enabled.
    The result is stored in the module-level ``pipeline`` so later calls
    return the same object without reloading.

    Returns:
        The cached pipeline object.

    Raises:
        RuntimeError: If no CUDA device is available when the pipeline has
            to be built.
    """
    global pipeline

    if pipeline is None:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA GPU is not available. Please enable GPU hardware on the Hugging Face Space."
            )

        # NOTE(review): trust_remote_code=True executes pipeline code fetched
        # from the model hub; keep the repository identifiers pinned/trusted.
        load_kwargs = {
            "custom_pipeline": CUSTOM_PIPELINE,
            "torch_dtype": torch.float16,
            "trust_remote_code": True,
        }
        loaded = DiffusionPipeline.from_pretrained(MODEL_ID, **load_kwargs)
        loaded.to("cuda")
        loaded.enable_attention_slicing()

        # Publish only after every setup step succeeded, so a failed load
        # leaves the cache empty and the next call retries.
        pipeline = loaded

    return pipeline
def image_to_bytes(image: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode *image* into an in-memory buffer and return the encoded bytes.

    Args:
        image: Object whose ``save(file, format=...)`` method writes the
            encoded image to a binary file object.
        fmt: Format name forwarded to ``image.save``; defaults to ``"PNG"``.

    Returns:
        Everything ``image.save`` wrote to the buffer.
    """
    with io.BytesIO() as sink:
        image.save(sink, format=fmt)
        # getvalue() returns the whole buffer regardless of stream position.
        return sink.getvalue()
# Lazily created rembg session shared by every crop_selected_views() call.
_rembg_session = None


def _get_rembg_session():
    """Return the shared rembg session, creating it on first use."""
    global _rembg_session
    if _rembg_session is None:
        # Local import: the file-level import only brings in `remove`.
        from rembg import new_session

        _rembg_session = new_session()
    return _rembg_session


def crop_selected_views(grid: Image.Image) -> Dict[str, Image.Image]:
    """Cut the selected views out of a Zero123++ grid and whiten their background.

    Zero123++ output is expected as a 2-column x 3-row grid. Tiles are
    numbered 1..6 in row-major order (left to right, top to bottom), and
    views 1, 3, 5 and 6 are kept for LGM. Each kept tile is resized to
    256x256, has its background removed with rembg, and is flattened onto
    a white background.

    Args:
        grid: The multi-view grid image. Tile size is the grid size
            floor-divided by the column/row counts.

    Returns:
        A dict mapping ``"view_<index>_<name>.png"`` to the processed RGB
        tile, in ascending view-index order.
    """
    grid = grid.convert("RGB")
    width, height = grid.size
    cols, rows = 2, 3
    tile_w, tile_h = width // cols, height // rows

    selected = {
        1: "front_right",
        3: "right",
        5: "back_right",
        6: "back_left",
    }

    # Reuse one session for all tiles. Without an explicit session, rembg's
    # remove() sets up a new default session on every call, which repeats the
    # model load for each of the four tiles.
    session = _get_rembg_session()

    outputs = {}
    for idx_1based, name in selected.items():
        # Convert the 1-based, row-major tile index to grid coordinates.
        row, col = divmod(idx_1based - 1, cols)
        box = (
            col * tile_w,
            row * tile_h,
            (col + 1) * tile_w,
            (row + 1) * tile_h,
        )
        tile = grid.crop(box).resize((256, 256), Image.LANCZOS)

        # Remove background and paste on white, using the cutout's alpha
        # channel as the paste mask.
        tile_rgba = remove(tile.convert("RGBA"), session=session)
        white_bg = Image.new("RGBA", tile_rgba.size, (255, 255, 255, 255))
        white_bg.paste(tile_rgba, mask=tile_rgba.split()[3])

        outputs[f"view_{idx_1based}_{name}.png"] = white_bg.convert("RGB")
    return outputs
# NOTE(review): the handler was not registered on `app`; the "/" path is
# assumed from the function name — confirm against existing clients.
@app.get("/")
def root():
    """Return a static description of the service.

    Returns:
        A dict with the service status, service name, the model identifier
        (``MODEL_ID``) and a short description of the generated output.
    """
    return {
        "status": "running",
        "service": "zero123plus-inference",
        "model": MODEL_ID,
        "output": "6-view grid + cropped views 1, 3, 5, 6 for LGM",
    }
# NOTE(review): the handler was not registered on `app`; the "/health" path
# is assumed from the function name — confirm against existing clients.
@app.get("/health")
def health():
    """Report liveness and CUDA availability.

    Returns:
        A dict with ``"status"`` set to ``"ok"``, ``"cuda_available"`` as
        reported by ``torch.cuda.is_available()``, and ``"cuda_device"``
        holding the name of device 0, or ``None`` when CUDA is unavailable.
    """
    # Query once so both fields are derived from the same answer.
    cuda_ok = torch.cuda.is_available()
    return {
        "status": "ok",
        "cuda_available": cuda_ok,
        "cuda_device": torch.cuda.get_device_name(0) if cuda_ok else None,
    }
# NOTE(review): the handler was not registered on `app`; POST "/generate" is
# assumed from the function name and the file upload parameter — confirm
# against existing clients.
@app.post("/generate")
async def generate(file: UploadFile = File(...), steps: int = 75):
    """Generate Zero123++ views for an uploaded image and return them as a zip.

    The archive contains ``multiview_grid.png`` (the raw pipeline output)
    plus the tiles produced by ``crop_selected_views``.

    Args:
        file: Uploaded input image.
        steps: Number of inference steps passed to the pipeline; must be
            at least 1.

    Returns:
        A ``StreamingResponse`` carrying the zip archive on success.
        A ``JSONResponse`` with status 400 when ``steps`` is below 1 or the
        upload cannot be decoded as an image, and with status 500 when
        pipeline loading, inference or packaging fails. Error bodies always
        have the keys ``"error"`` and ``"message"``.
    """
    # Local import keeps this block self-contained.
    import logging

    if steps < 1:
        return JSONResponse(
            status_code=400,
            content={
                "error": f"steps must be >= 1, got {steps}",
                "message": "Invalid 'steps' parameter.",
            },
        )

    contents = await file.read()

    # Decode the upload first so that a bad file is reported as a client
    # error instead of a generation failure.
    try:
        input_image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # PIL's UnidentifiedImageError is an OSError subclass.
        return JSONResponse(
            status_code=400,
            content={
                "error": str(e),
                "message": "Uploaded file is not a readable image.",
            },
        )

    try:
        pipe = load_pipeline()

        # NOTE(review): this call is synchronous inside an async handler, so
        # it occupies the event loop for the whole inference.
        with torch.inference_mode():
            result = pipe(input_image, num_inference_steps=steps).images[0]

        cropped_views = crop_selected_views(result)

        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("multiview_grid.png", image_to_bytes(result))
            for filename, image in cropped_views.items():
                zf.writestr(filename, image_to_bytes(image))
        zip_buffer.seek(0)

        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": "attachment; filename=zero123plus_outputs.zip"
            },
        )
    except Exception as e:
        # The response tells the caller to check the logs, so make sure the
        # traceback is actually written there.
        logging.getLogger(__name__).exception("Zero123++ generation failed")
        return JSONResponse(
            status_code=500,
            content={
                "error": str(e),
                "message": "Zero123++ generation failed. Check Space logs for details.",
            },
        )