Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| from .dsine.dsine import DSINE | |
| from .dsine import utils as dsine_utils | |
class NormalDetector:
    """Predict a surface-normal map for an image with a DSINE model.

    The checkpoint is loaded once in the constructor. Each call moves the
    model to ``self.device`` for the forward pass and moves it back to the
    CPU afterwards, even when inference raises.
    """

    def __init__(self, model_path):
        """Build the DSINE model and load its weights.

        Args:
            model_path: Checkpoint path handed to
                ``dsine_utils.load_checkpoint``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DSINE()
        self.model = dsine_utils.load_checkpoint(model_path, self.model)
        # This class is inference-only: switch to evaluation mode so any
        # dropout / batch-norm layers stop behaving as in training.
        self.model.eval()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        # Field of view used to synthesise camera intrinsics
        # (presumably degrees -- TODO confirm against get_intrins_from_fov).
        self.fov = 60

    def __call__(self, image):
        """Return the predicted normal map of *image* as a PIL image.

        Args:
            image: A PIL image (converted to RGB if needed) or an
                ``(H, W, 3)`` array with values on a 0-255 scale.

        Returns:
            A ``PIL.Image.Image`` of size ``(W, H)`` whose channels hold the
            prediction rescaled with ``(x + 1) / 2 * 255`` (i.e. assuming the
            model emits components in ``[-1, 1]``).
        """
        # Grayscale / RGBA / palette PIL images would otherwise fail in
        # ``permute`` or in the 3-channel ``Normalize`` below.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")

        try:
            self.model.to(self.device)
            # ``pixel_coords`` is moved by hand alongside the module, as the
            # original code does.
            self.model.pixel_coords = self.model.pixel_coords.to(self.device)

            img = np.array(image).astype(np.float32) / 255.0
            img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
            _, _, orig_H, orig_W = img.shape

            # Pad by the amounts the DSINE helper requests, then normalise.
            l, r, t, b = dsine_utils.pad_input(orig_H, orig_W)
            img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
            img = self.normalize(img)

            intrinsics = dsine_utils.get_intrins_from_fov(
                new_fov=self.fov, H=orig_H, W=orig_W, device=self.device
            ).unsqueeze(0)
            # Shift the principal-point entries by the left/top padding so
            # they refer to the same pixel in the padded image.
            intrinsics[:, 0, 2] += l
            intrinsics[:, 1, 2] += t

            # No gradients are needed for inference; skipping autograd
            # avoids retaining activations in memory.
            with torch.no_grad():
                pred_norm = self.model(img, intrins=intrinsics)[-1]

            # Crop the padding back off.
            pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
            pred_norm_np = (
                pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
            )  # (H, W, 3)
            # Clip before the uint8 cast so values marginally outside the
            # expected range cannot wrap around.
            pred_norm_np = np.clip((pred_norm_np + 1.0) / 2.0 * 255.0, 0.0, 255.0)
            pred_norm_np = pred_norm_np.astype(np.uint8)
            normal_img = Image.fromarray(pred_norm_np).resize((orig_W, orig_H))
        finally:
            # Always release the accelerator, even if inference failed.
            self.model.to("cpu")
            self.model.pixel_coords = self.model.pixel_coords.to("cpu")
        return normal_img
if __name__ == "__main__":
    # Manual smoke test: download a sample image, predict its normal map and
    # save the result next to the working directory.
    # NOTE: this module uses relative imports, so run it as part of its
    # package (``python -m <package>.<module>``) rather than as a bare file.
    from diffusers.utils import load_image

    image = load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    )
    image = np.asarray(image)
    # NormalDetector.__init__ only accepts ``model_path``; passing the former
    # ``efficientnet_path`` keyword raised TypeError before anything ran.
    normal_detector = NormalDetector(
        model_path="/juicefs/training/models/open_source/dsine/dsine.pt",
    )
    normal_image = normal_detector(image)
    normal_image.save("normal_image.jpg")