| | import os |
| | import cv2 |
| | import gradio as gr |
| | import numpy as np |
| | import spaces |
| | import torch |
| | import torch.nn.functional as F |
| | from gradio.themes.utils import sizes |
| | from PIL import Image |
| | from torchvision import transforms |
| | import tempfile |
| |
|
class Config:
    """Static filesystem locations and checkpoint filenames used by the demo."""

    # Bundled assets live next to this source file.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    # TorchScript checkpoints are stored in a subfolder of the assets dir.
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # Normal-estimation model size -> TorchScript checkpoint filename.
    # Key insertion order is the order offered in the model-size dropdown.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2",
        "0.6b": "sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2",
        "1b": "sapiens_1b_normal_render_people_epoch_115_torchscript.pt2",
        "2b": "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2",
    }

    # Background-removal option -> segmentation checkpoint filename.
    # A value of None means "load no segmentation model" (skip masking).
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
| |
|
class ModelManager:
    """Helpers to load TorchScript checkpoints and run them at a target size."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load a TorchScript checkpoint from ``Config.CHECKPOINTS_DIR``.

        Args:
            checkpoint_name: Filename inside the checkpoints directory, or
                ``None`` to indicate that no model should be loaded.

        Returns:
            The loaded module in eval mode, moved to ``"cuda"``; or ``None``
            when ``checkpoint_name`` is ``None``.
        """
        if checkpoint_name is not None:
            path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
            scripted = torch.jit.load(path)
            scripted.eval()
            scripted.to("cuda")
            return scripted
        return None

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        """Run ``model`` on ``input_tensor`` and resize its output.

        The raw prediction is bilinearly resampled (``align_corners=False``)
        to spatial size ``(height, width)``; autograd is disabled via
        ``torch.inference_mode``.
        """
        raw_prediction = model(input_tensor)
        target_size = (height, width)
        return F.interpolate(
            raw_prediction,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
| |
|
class ImageProcessor:
    """Estimate surface normals for a PIL image, optionally masking the background."""

    def __init__(self):
        # Preprocessing: resize to the fixed (1024, 768) model input, convert to
        # a float tensor, then normalize with per-channel mean/std that are
        # written on the 0-255 scale and divided down to match ToTensor's range.
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255], std=[58.5/255, 57.0/255, 57.5/255]),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, normal_model_name: str, seg_model_name: str):
        """Run normal estimation (and optional background removal) on ``image``.

        Args:
            image: Input picture. It is converted to RGB first, so RGBA,
                grayscale or palette images are accepted.
            normal_model_name: Key into ``Config.CHECKPOINTS``.
            seg_model_name: Key into ``Config.SEG_CHECKPOINTS``. A key whose
                checkpoint is ``None`` skips background removal.

        Returns:
            ``(visualization, npy_path)``: an RGB uint8 ``PIL.Image`` of the
            normals, and the path of a ``.npy`` file holding the raw normal map
            at the input resolution (channels last). When a segmentation model
            is used, background pixels are NaN in that file.

        Raises:
            KeyError: If either model name is not a known key.
        """
        # Normalize() carries 3 channel statistics, so force 3-channel input.
        image = image.convert("RGB")
        height, width = image.height, image.width

        normal_model = ModelManager.load_model(Config.CHECKPOINTS[normal_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")

        normal_output = ModelManager.run_model(normal_model, input_tensor, height, width)
        # Take batch element 0 explicitly; squeeze() would also drop a size-1
        # height or width. Then move channels last: (C, H, W) -> (H, W, C).
        normal_map = normal_output[0].cpu().numpy().transpose(1, 2, 0)

        # Separate copy for display: the saved array marks background with NaN,
        # while the visualization needs a finite sentinel value instead.
        normal_map_vis = normal_map.copy()

        # load_model returns None for the "no background removal" entry.
        seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
        if seg_model is not None:
            seg_output = ModelManager.run_model(seg_model, input_tensor, height, width)
            # Class index 0 is treated as background; any other class is foreground.
            seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
            background = seg_mask == 0
            normal_map[background] = np.nan
            normal_map_vis[background] = -1

        normal_map_vis = self.visualize_normal_map(normal_map_vis)

        # tempfile.mktemp() is deprecated and racy. Create the file atomically
        # and keep it on disk (delete=False) so Gradio can serve it. Writing
        # through the file object means np.save does not append another suffix.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
            np.save(tmp, normal_map)
            npy_path = tmp.name

        return Image.fromarray(normal_map_vis), npy_path

    @staticmethod
    def visualize_normal_map(normal_map):
        """Convert an (..., 3) normal array into a uint8 RGB image array.

        Each vector is normalized to unit length (with a 1e-5 epsilon to avoid
        division by zero). Components in [-1, 1] are then mapped linearly to
        [0, 255].
        """
        normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
        normal_map_normalized = normal_map / (normal_map_norm + 1e-5)
        normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
        return normal_map_vis
| |
|
class GradioInterface:
    """Build the Gradio Blocks UI for the Sapiens normal-estimation demo."""

    def __init__(self):
        # One processor instance is shared by every request from this UI.
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` app without launching it.

        Layout: a header (styles, title, links), an input column (image, model
        dropdowns, example gallery), and an output column (visualization,
        ``.npy`` download, Run button).

        Raises:
            OSError: If the example-images directory cannot be listed.
        """
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }

            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }

            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }

            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }

            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }

            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }

            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }

            .publication-link:hover {
                background-color: #555;
            }

            .publication-link i {
                margin-right: 8px;
            }

            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }

            /* Image Styles */
            .image-preview img {
                max-width: 100%;
                max-height: 512px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }

            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }

            .gr-button:hover {
                background-color: #5a5a5a;
            }

            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }

            .gr-form {
                background-color: transparent;
            }

            .gr-panel {
                border: none;
                background-color: transparent;
            }

            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens: Normal Estimation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
                This demo showcases the finetuned normal estimation model. <br>
                Checkout other normal estimation baselines to compare: <a href="https://huggingface.co/spaces/Stable-X/normal-estimation-arena" style="color: #3273dc;">normal-estimation-arena</a>
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

        def process_image(image, normal_model_name, seg_model_name):
            # Reject a missing input before entering the GPU-decorated
            # processor. Otherwise an empty click fails with an opaque
            # AttributeError deep inside preprocessing.
            if image is None:
                raise gr.Error("Please upload an input image first.")
            result, npy_path = self.image_processor.process_image(image, normal_model_name, seg_model_name)
            return result, npy_path

        # Client-side hook: reload once with ?__theme=dark if that query
        # parameter is not already set to "dark".
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        # os.listdir order is arbitrary; sort so the example gallery is stable.
        images_dir = os.path.join(Config.ASSETS_DIR, "images")
        example_paths = [
            os.path.join(images_dir, img)
            for img in sorted(os.listdir(images_dir))
        ]

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        normal_model_name = gr.Dropdown(
                            label="Normal Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=example_paths,
                    )
                with gr.Column():
                    result_image = gr.Image(label="Normal Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background normal is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=process_image,
                inputs=[input_image, normal_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )

        return demo
| |
|
def main():
    """Enable TF32 math on capable GPUs, then build and launch the demo UI."""
    # `and` short-circuits, so device properties are only queried when CUDA is
    # present. Compute capability major >= 8 gates the TF32 switches below.
    gpu_supports_tf32 = (
        torch.cuda.is_available()
        and torch.cuda.get_device_properties(0).major >= 8
    )
    if gpu_supports_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    app = GradioInterface().create_interface()
    app.launch(share=False)


if __name__ == "__main__":
    main()