Spaces:
Runtime error
Runtime error
| import huggingface_hub | |
| huggingface_hub.cached_download = huggingface_hub.hf_hub_download | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import io | |
| import base64 | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| from torchvision.transforms import v2 | |
| from pytorch_lightning import seed_everything | |
| from omegaconf import OmegaConf | |
| from einops import rearrange | |
| from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler | |
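# NOTE: the `huggingface_hub.cached_download = huggingface_hub.hf_hub_download`
# alias at the top of this file is a monkey-patch for diffusers<=0.19.3, which
# still does
#     from huggingface_hub import cached_download
# Newer HF-Hub versions (>=0.14.0 according to the original note; verify the
# exact release) removed cached_download, so the alias has to be installed
# before diffusers is imported.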
| # your util functions & model loaders | |
| from src.utils.train_util import instantiate_from_config | |
| from src.utils.camera_util import ( | |
| FOV_to_intrinsics, | |
| get_zero123plus_input_cameras, | |
| get_circular_camera_poses, | |
| ) | |
| from src.utils.mesh_util import save_obj, save_glb | |
| from src.utils.infer_util import remove_background, resize_foreground | |
# ─────────────────────────────────────────────────────
# 1) CONFIGURATION & MODEL LOADING
# ─────────────────────────────────────────────────────
# Load our YAML config. Its `model_config` section is used to build the
# reconstruction model and `infer_config` is forwarded to mesh extraction.
config_path = 'configs/instant-mesh-large.yaml'
config = OmegaConf.load(config_path)
model_config = config.model_config
infer_config = config.infer_config
# Configs whose file name starts with "instant-mesh" enable the FlexiCubes
# code path (see the IS_FLEXICUBES checks further down).
IS_FLEXICUBES = os.path.basename(config_path).startswith('instant-mesh')

# pick device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print("⚠️ No CUDA GPU detected. Falling back to CPU (this will be very slow!)")

# choose torch dtype: float16 on GPU, float32 on CPU
torch_dtype = torch.float16 if device.type == 'cuda' else torch.float32

# ─── Load diffusion (Zero123) pipeline ───
print("Loading diffusion model …")
pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch_dtype,
)
# Swap in an Euler-ancestral scheduler built from the pipeline's own scheduler
# config, overriding only the timestep spacing.
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing='trailing'
)

# patch UNet to white-background version
unet_ckpt = hf_hub_download(
    repo_id="TencentARC/InstantMesh",
    filename="diffusion_pytorch_model.bin",
    repo_type="model",
)
# SECURITY NOTE(review): torch.load unpickles the downloaded file, which can
# execute arbitrary code if the checkpoint is not trusted. Consider
# `weights_only=True` once confirmed compatible with this checkpoint.
sd = torch.load(unet_ckpt, map_location='cpu')
# strict=True: fail loudly if the checkpoint keys do not match the UNet.
pipeline.unet.load_state_dict(sd, strict=True)
pipeline = pipeline.to(device)

# ─── Load reconstruction (InstantMesh) model ───
print("Loading reconstruction model …")
model_ckpt = hf_hub_download(
    repo_id="TencentARC/InstantMesh",
    filename="instant_mesh_large.ckpt",
    repo_type="model",
)
model = instantiate_from_config(model_config)
# SECURITY NOTE(review): same pickle caveat as above; this checkpoint is a
# dict with a 'state_dict' entry, so `weights_only=True` may or may not work.
full_sd = torch.load(model_ckpt, map_location='cpu')['state_dict']
# strip the "lrm_generator." prefix & unwanted keys
# (keys without the prefix, or mentioning "source_camera", are dropped)
sd = {
    k[len("lrm_generator.") :]: v
    for k, v in full_sd.items()
    if k.startswith("lrm_generator.") and "source_camera" not in k
}
model.load_state_dict(sd, strict=True)
model = model.to(device).eval()
print("Models loaded ✅")
# ─────────────────────────────────────────────────────
# 2) HELPERS & INFERENCE LOGIC
# ─────────────────────────────────────────────────────
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
    """Build a batch of render cameras from a circular set of poses.

    ``get_circular_camera_poses`` supplies ``M`` poses for the given ``radius``
    and ``elevation`` (presumably camera-to-world matrices, going by the
    original variable name ``c2ws`` -- confirm against camera_util).

    * ``is_flexicubes=True``: the poses are inverted with ``torch.linalg.inv``
      and repeated ``batch_size`` times along a new leading axis (4-D result).
    * otherwise: each pose is flattened over its last two dims and
      concatenated with the flattened ``FOV_to_intrinsics(50.0)`` matrix, then
      repeated ``batch_size`` times along a new leading axis (3-D result).
    """
    poses = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)

    if is_flexicubes:
        inverted = torch.linalg.inv(poses)
        return inverted.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    # Pack per-view extrinsics and (shared) intrinsics into one flat vector.
    extrinsics = poses.flatten(-2)
    intrinsics = FOV_to_intrinsics(50.0)
    intrinsics = intrinsics.unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
    packed = torch.cat([extrinsics, intrinsics], dim=-1)
    return packed.unsqueeze(0).repeat(batch_size, 1, 1)
def preprocess(input_image: Image.Image, do_remove_background: bool):
    """Optionally strip the background from *input_image*.

    When ``do_remove_background`` is false the image is returned untouched.
    Otherwise a fresh ``rembg`` session is created, the background is removed
    with ``remove_background`` and the result is passed through
    ``resize_foreground`` with a ratio of 0.85.

    NOTE(review): ``resize_foreground`` is treated as part of the
    background-removal branch; confirm that this matches the intended flow.
    """
    if not do_remove_background:
        return input_image

    # rembg is only needed on this path, so it is imported lazily.
    import rembg

    session = rembg.new_session()
    without_background = remove_background(input_image, session)
    return resize_foreground(without_background, 0.85)
def generate_mvs(
    input_image: Image.Image, sample_steps: int, sample_seed: int
) -> tuple[Image.Image, Image.Image]:
    """Return (raw_multi_view, preview_image).

    Seeds the RNGs with ``sample_seed``, runs the diffusion ``pipeline`` on
    ``input_image`` for ``sample_steps`` inference steps and takes the first
    output image as the raw multi-view grid. The preview is the same pixels
    re-tiled from a 3-row x 2-column grid into a 2-row x 3-column grid.
    """
    seed_everything(sample_seed)
    multi_view = pipeline(input_image, num_inference_steps=sample_steps).images[0]

    # Split the H x W x 3 grid into six tiles, then lay them out as 2 x 3.
    grid = torch.from_numpy(np.asarray(multi_view, dtype=np.uint8))
    tiles = rearrange(grid, "(n h) (m w) c -> (n m) h w c", n=3, m=2)
    retiled = rearrange(tiles, "(n m) h w c -> (n h) (m w) c", n=2, m=3)

    return multi_view, Image.fromarray(retiled.numpy())
def make3d(
    mv: Image.Image,
) -> tuple[str, str]:
    """Return (path_to_obj, path_to_glb).

    Reconstructs a mesh from the multi-view grid image ``mv`` (a 3-row x
    2-column grid of views) and writes it as both an OBJ and a GLB file that
    share the same temporary base name. The files are NOT deleted here; the
    caller owns them and is responsible for removing them.
    """
    # initialize flexicubes if needed
    if IS_FLEXICUBES:
        model.init_flexicubes_geometry(device, use_renderer=False)

    # normalize to [0, 1] and split the grid into six (c, h, w) views
    pixels = np.asarray(mv, dtype=np.float32) / 255.0
    t = torch.from_numpy(pixels).permute(2, 0, 1).contiguous().float()
    t = rearrange(t, "c (n h) (m w) -> (n m) c h w", n=3, m=2)

    cam_in = get_zero123plus_input_cameras(1, radius=4.0).to(device)

    t = t.unsqueeze(0).to(device)
    # interpolation=3 is presumably the PIL integer code for bicubic -- confirm.
    # clamp() keeps the resized values inside [0, 1].
    t = v2.functional.resize(t, (320, 320), interpolation=3, antialias=True).clamp(0, 1)

    # temp file names: mkstemp hands back an open descriptor, which is closed
    # right away so no file handle is leaked; only the path is needed.
    fd, obj_f = tempfile.mkstemp(suffix=".obj")
    os.close(fd)
    glb_f = os.path.splitext(obj_f)[0] + ".glb"

    with torch.no_grad():
        planes = model.forward_planes(t, cam_in)
        mesh = model.extract_mesh(
            planes, use_texture_map=False, **infer_config
        )
        verts, faces, colors = mesh
        # reorder the vertex coordinate columns (x, y, z) -> (y, z, x)
        verts = verts[:, [1, 2, 0]]
        save_obj(verts, faces, colors, obj_f)
        save_glb(verts, faces, colors, glb_f)
    return obj_f, glb_f
def _pil_to_b64(img: Image.Image, fmt: str = "PNG") -> str:
    """Encode *img* in image format *fmt* in memory and return it as base64 text."""
    with io.BytesIO() as stream:
        img.save(stream, fmt)
        raw = stream.getvalue()
    encoded = base64.b64encode(raw)
    return encoded.decode()
# ─────────────────────────────────────────────────────
# 3) FASTAPI APP
# ─────────────────────────────────────────────────────
app = FastAPI(title="InstantMesh FastAPI Demo")


# Without a route decorator the handler is never registered and the app
# serves no endpoint at all.
# NOTE(review): the path "/infer" is taken from the function name -- confirm.
@app.post("/infer")
async def infer(
    file: UploadFile = File(...),
    remove_background: bool = Form(True),
    sample_steps: int = Form(75, ge=1, le=100),
    sample_seed: int = Form(42),
):
    """Turn an uploaded image into multi-view images and a 3D mesh.

    Form fields:
        file: the input image upload.
        remove_background: whether to strip the background before generation.
            (Inside this function the name shadows the imported
            ``remove_background`` helper; only the boolean is used here.)
        sample_steps: diffusion inference steps, validated to 1..100.
        sample_seed: seed passed to the multi-view generator.

    Returns:
        A JSON object with base64-encoded ``preview_png``, ``multi_views_png``,
        ``obj_data_b64`` and ``glb_data_b64``.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an image.

    NOTE(review): the generation calls below are synchronous and run inside
    this coroutine, so the event loop is blocked for the whole request.
    """
    # 1) load the RGBA image
    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data)).convert("RGBA")
    except Exception as err:
        raise HTTPException(400, detail="Invalid image") from err

    # 2) run through pipeline
    proc = preprocess(img, remove_background)
    mv_raw, mv_preview = generate_mvs(proc, sample_steps, sample_seed)
    obj_path, glb_path = make3d(mv_raw)

    # 3) read back the mesh bytes, then always delete the temporary files so
    #    they do not pile up on disk request after request
    try:
        with open(obj_path, "rb") as f:
            obj_b = f.read()
        with open(glb_path, "rb") as f:
            glb_b = f.read()
    finally:
        for path in (obj_path, glb_path):
            try:
                os.remove(path)
            except OSError:
                # best-effort cleanup: a missing file must not fail the request
                pass

    return JSONResponse(
        {
            "preview_png": _pil_to_b64(mv_preview),
            "multi_views_png": _pil_to_b64(mv_raw),
            "obj_data_b64": base64.b64encode(obj_b).decode(),
            "glb_data_b64": base64.b64encode(glb_b).decode(),
        }
    )
if __name__ == "__main__":
    import uvicorn

    # Serve the already-built `app` object directly. The previous
    # `"main:app"` + `reload=True` form makes uvicorn import the module again
    # in a child process, so both models were loaded a second time, and it
    # only works when this file is literally named main.py.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 8000)),
    )