carpedm20 commited on
Commit
f6e4f38
·
verified ·
1 Parent(s): 1879f62

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
38
+ assets/palette.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sapiens Segmentation
3
+ emoji: 🌍
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.42.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-4.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import tempfile
10
+
11
+ from gradio.themes.utils import sizes
12
+ from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
13
+
14
+
15
+ # =========================================================
16
+ # Config
17
+ # =========================================================
18
+
19
class Config:
    """Static configuration: asset locations and available checkpoints."""

    # Directory layout: <app dir>/assets/checkpoints/<checkpoint file>.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # TorchScript checkpoint filenames keyed by model size. The 2b
    # checkpoint ships in assets/checkpoints alongside the others, so it
    # is exposed here as well (previously uploaded but unselectable).
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
        "2b": "sapiens_2b_goliath_best_goliath_mIoU_8131_epoch_200_torchscript.pt2",
    }
+ }
27
+
28
+
29
+ # =========================================================
30
+ # Model
31
+ # =========================================================
32
+
33
class ModelManager:
    """Loads, caches, and runs the TorchScript segmentation models."""

    # Cache of loaded models keyed by checkpoint name, so switching model
    # sizes in the UI does not reload from disk every time.
    _cache = {}

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load (or fetch from cache) the TorchScript model for *checkpoint_name*.

        Falls back to CPU when CUDA is unavailable instead of crashing on
        CPU-only hosts (the original hard-coded "cuda").
        """
        if checkpoint_name in ModelManager._cache:
            return ModelManager._cache[checkpoint_name]

        checkpoint_path = os.path.join(
            Config.CHECKPOINTS_DIR,
            Config.CHECKPOINTS[checkpoint_name],
        )
        model = torch.jit.load(checkpoint_path)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        ModelManager._cache[checkpoint_name] = model
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        """Run *model* on *input_tensor* and return per-pixel class ids.

        The raw logits are upsampled back to the original image size
        (height, width) before the argmax, so the returned prediction map
        aligns with the input image, shape (N, height, width).
        """
        output = model(input_tensor)
        output = F.interpolate(
            output,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        # Argmax over the class dimension -> integer label map.
        _, preds = torch.max(output, 1)
        return preds
63
+
64
+
65
+ # =========================================================
66
+ # Image Processing
67
+ # =========================================================
68
+
69
class ImageProcessor:
    """Preprocesses input images and renders segmentation overlays."""

    def __init__(self):
        # Model input is a fixed 1024x768 tensor; mean/std are given in the
        # 0-255 range and rescaled to 0-1 to match ToTensor's output.
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])

    def process_image(self, image: Image.Image, model_name: str):
        """Segment *image* with the model named *model_name*.

        Returns (blended PIL image, path to a .npy file holding the raw
        integer label map at the original image resolution).
        """
        model = ModelManager.load_model(model_name)
        # Keep the input on the same device the model was loaded onto
        # (CUDA when available, CPU otherwise).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        input_tensor = self.transform_fn(image).unsqueeze(0).to(device)

        preds = ModelManager.run_model(
            model,
            input_tensor,
            image.height,
            image.width,
        )

        mask = preds.squeeze(0).cpu().numpy()
        blended_image = self.visualize_pred_with_overlay(image, mask)

        # NamedTemporaryFile replaces the deprecated, race-prone
        # tempfile.mktemp; delete=False so Gradio can serve the file.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
            np.save(tmp, mask)
            npy_path = tmp.name

        return blended_image, npy_path

    @staticmethod
    def visualize_pred_with_overlay(img, sem_seg, alpha=0.5):
        """Alpha-blend the class-colored segmentation over the RGB image.

        Labels >= len(GOLIATH_CLASSES) are ignored (left black in the
        overlay); *alpha* is the overlay weight in [0, 1].
        """
        img_np = np.array(img.convert("RGB"))
        sem_seg = np.array(sem_seg)

        num_classes = len(GOLIATH_CLASSES)
        ids = np.unique(sem_seg)
        ids = ids[ids < num_classes]

        overlay = np.zeros((*sem_seg.shape, 3), dtype=np.uint8)
        for label in ids:
            overlay[sem_seg == label] = GOLIATH_PALETTE[label]

        blended = np.uint8(img_np * (1 - alpha) + overlay * alpha)
        return Image.fromarray(blended)
114
+
115
+
116
+ # =========================================================
117
+ # UI
118
+ # =========================================================
119
+
120
class GradioInterface:
    """Builds the Gradio Blocks UI around the segmentation pipeline."""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble and return the Gradio Blocks demo (not yet launched)."""
        # -------------------------
        # Theme (modern Gradio)
        # -------------------------
        theme = gr.themes.Soft(
            primary_hue="neutral",
            secondary_hue="slate",
            neutral_hue="zinc",
            radius_size=sizes.radius_md,
            text_size=sizes.text_md,
        ).set(
            body_background_fill="#1a1a1a",
            body_text_color="#fafafa",
            block_background_fill="#2a2a2a",
            block_border_color="#333333",
            button_primary_background_fill="#4a4a4a",
            button_primary_background_fill_hover="#5a5a5a",
            input_background_fill="#3a3a3a",
        )

        # -------------------------
        # Minimal CSS (layout only)
        # -------------------------
        css = """
        .image-preview img {
            max-width: 512px;
            max-height: 512px;
            margin: 0 auto;
            display: block;
            object-fit: contain;
            border-radius: 6px;
        }
        .app-header {
            padding: 24px;
            margin-bottom: 24px;
            text-align: center;
        }
        .app-title {
            font-size: 48px;
            font-weight: 700;
        }
        .app-subtitle {
            font-size: 24px;
            opacity: 0.9;
        }
        .publication-links {
            display: flex;
            justify-content: center;
            flex-wrap: wrap;
            gap: 8px;
            margin-top: 12px;
        }
        """

        header_html = """
        <div class="app-header">
            <h1 class="app-title">Sapiens: Body-Part Segmentation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p>
                Foundation models for human-centric vision tasks pretrained on
                300M human images. This demo showcases fine-tuned body-part
                segmentation.
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569">arXiv</a>
                <a href="https://github.com/facebookresearch/sapiens">GitHub</a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/">Meta</a>
            </div>
        </div>
        """

        def process(image, model_name):
            # Thin closure so the click handler captures self.
            return self.image_processor.process_image(image, model_name)

        with gr.Blocks(theme=theme, css=css) as demo:
            gr.HTML(header_html)

            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                        elem_classes="image-preview",
                    )

                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )

                    # sorted() keeps the example gallery deterministic:
                    # os.listdir order is filesystem-dependent.
                    images_dir = os.path.join(Config.ASSETS_DIR, "images")
                    gr.Examples(
                        inputs=input_image,
                        examples=[
                            os.path.join(images_dir, img)
                            for img in sorted(os.listdir(images_dir))
                        ],
                        examples_per_page=14,
                    )

                with gr.Column():
                    result_image = gr.Image(
                        label="Segmentation Result",
                        type="pil",
                        elem_classes="image-preview",
                    )
                    npy_output = gr.File(label="Segmentation (.npy)")
                    run_button = gr.Button("Run", variant="primary")

            gr.Image(
                os.path.join(Config.ASSETS_DIR, "palette.jpg"),
                label="Class Palette",
                type="filepath",
                elem_classes="image-preview",
            )

            run_button.click(
                fn=process,
                inputs=[input_image, model_name],
                outputs=[result_image, npy_output],
            )

        return demo
249
+
250
+
251
+ # =========================================================
252
+ # Entrypoint
253
+ # =========================================================
254
+
255
def main():
    """Enable fast math where safe, then build and serve the demo."""
    # TF32 matmuls only help on Ampere-class GPUs (compute capability 8.x+).
    has_ampere = (
        torch.cuda.is_available()
        and torch.cuda.get_device_properties(0).major >= 8
    )
    if has_ampere:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    ui = GradioInterface()
    ui.create_interface().launch(server_name="0.0.0.0", share=False)


if __name__ == "__main__":
    main()
267
+
assets/checkpoints/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:735a9a8d63fe8f3f6a4ca3d787de07e69b1f9708ad550e09bb33c9854b7eafbc
3
+ size 1358871599
assets/checkpoints/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86aa2cb9d7310ba1cb1971026889f1d10d80ddf655d6028aea060aae94d82082
3
+ size 2685144079
assets/checkpoints/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33bba30f3de8d9cfd44e4eaa4817b1bfdd98c188edfc87fa7cc031ba0f4edc17
3
+ size 4716314057
assets/checkpoints/sapiens_2b_goliath_best_goliath_mIoU_8131_epoch_200_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afe0970265f2af97f9eeb625036f147730d56820d6891803b13a278160c0f98a
3
+ size 8706620345
assets/images/68204.png ADDED

Git LFS Details

  • SHA256: 9b0268cb801ed164864a4b5f6d131e0ac5cc2fbd149a6467d5d0c97da47122c2
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
assets/images/68210.png ADDED

Git LFS Details

  • SHA256: dbe5f80498af4ebd1ff09ae4184f37c20ba981e53bd554c3cc78d39ae0ee7fd7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.93 MB
assets/images/68658.png ADDED

Git LFS Details

  • SHA256: 61a68b619bd17235e683324f2826ce0693322e45ab8c86f1c057851ecb333ac7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.1 MB
assets/images/68666.png ADDED

Git LFS Details

  • SHA256: ea3047e6c2ccb485fdb3966aa2325e803cbf49c27c0bff00287b44bc16f18914
  • Pointer size: 132 Bytes
  • Size of remote file: 4.56 MB
assets/images/68691.png ADDED

Git LFS Details

  • SHA256: fae39e4055c1b297af7068cdddfeeba8d685363281b839d8c5afac1980204b57
  • Pointer size: 132 Bytes
  • Size of remote file: 3.74 MB
assets/images/68956.png ADDED

Git LFS Details

  • SHA256: eee1f27082b10999d0fa848121ecb06cda3386b1a864b9aa0f59ae78261f8908
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/images/pexels-amresh444-17315601.png ADDED

Git LFS Details

  • SHA256: 4e17ee1b229147e4b52e8348a6ef426bc9e9a2f90738e776e15b26b325abb9b3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.5 MB
assets/images/pexels-gabby-k-6311686.png ADDED

Git LFS Details

  • SHA256: 3f10eded3fb05ab04b963f7b9fd2e183d8d4e81b20569b1c6b0653549639421f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.65 MB
assets/images/pexels-julia-m-cameron-4145040.png ADDED

Git LFS Details

  • SHA256: 459cf0280667b028ffbca16aa11188780d7a0205c0defec02916ff3cbaeecb72
  • Pointer size: 132 Bytes
  • Size of remote file: 2.92 MB
assets/images/pexels-marcus-aurelius-6787357.png ADDED

Git LFS Details

  • SHA256: 7d35452f76492125eaf7d5783aa9fd6b0d5990ebe0579fe9dfd58a9d634f4955
  • Pointer size: 132 Bytes
  • Size of remote file: 3.3 MB
assets/images/pexels-mo-saeed-3616599-5409085.png ADDED

Git LFS Details

  • SHA256: 7c1ca7afd6c2a654e94ef59d5fb56fca4f3cde5fb5216f6b218c34a7b8c143dc
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
assets/images/pexels-riedelmax-27355495.png ADDED

Git LFS Details

  • SHA256: 4141d2f5f718f162ea1f6710c06b28b5cb51fd69598fde35948f8f3491228164
  • Pointer size: 132 Bytes
  • Size of remote file: 3.73 MB
assets/images/pexels-sergeymakashin-5368660.png ADDED

Git LFS Details

  • SHA256: af8f5a8f26dd102d87d94c1be36ec903791fe8e6d951c68ebb9ebcfc6d7397bb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.08 MB
assets/images/pexels-vinicius-wiesehofer-289347-4219918.png ADDED

Git LFS Details

  • SHA256: a6eef5eee15b81fe65ea95627e9a46040b9889466689b3c1ca6ed273e02fe84f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.63 MB
assets/palette.jpg ADDED

Git LFS Details

  • SHA256: b17692ef3956cbc93376b0238e8256b0759544b694d03f612f21219f6d9c3877
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
classes_and_palettes.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Full 34-class Goliath label set, in model output order.
ORIGINAL_GOLIATH_CLASSES = (
    "Background",
    "Apparel",
    "Chair",
    "Eyeglass_Frame",
    "Eyeglass_Lenses",
    "Face_Neck",
    "Hair",
    "Headset",
    "Left_Foot",
    "Left_Hand",
    "Left_Lower_Arm",
    "Left_Lower_Leg",
    "Left_Shoe",
    "Left_Sock",
    "Left_Upper_Arm",
    "Left_Upper_Leg",
    "Lower_Clothing",
    "Lower_Spandex",
    "Right_Foot",
    "Right_Hand",
    "Right_Lower_Arm",
    "Right_Lower_Leg",
    "Right_Shoe",
    "Right_Sock",
    "Right_Upper_Arm",
    "Right_Upper_Leg",
    "Torso",
    "Upper_Clothing",
    "Visible_Badge",
    "Lower_Lip",
    "Upper_Lip",
    "Lower_Teeth",
    "Upper_Teeth",
    "Tongue",
)

# RGB color per class, index-aligned with ORIGINAL_GOLIATH_CLASSES.
ORIGINAL_GOLIATH_PALETTE = [
    [50, 50, 50],
    [255, 218, 0],
    [102, 204, 0],
    [14, 0, 204],
    [0, 204, 160],
    [128, 200, 255],
    [255, 0, 109],
    [0, 255, 36],
    [189, 0, 204],
    [255, 0, 218],
    [0, 160, 204],
    [0, 255, 145],
    [204, 0, 131],
    [182, 0, 255],
    [255, 109, 0],
    [0, 255, 255],
    [72, 0, 255],
    [204, 43, 0],
    [204, 131, 0],
    [255, 0, 0],
    [72, 255, 0],
    [189, 204, 0],
    [182, 255, 0],
    [102, 0, 204],
    [32, 72, 204],
    [0, 145, 255],
    [14, 204, 0],
    [0, 128, 72],
    [204, 0, 43],
    [235, 205, 119],
    [115, 227, 112],
    [157, 113, 143],
    [132, 93, 50],
    [82, 21, 114],
]

## 6 classes to remove
REMOVE_CLASSES = (
    "Eyeglass_Frame",
    "Eyeglass_Lenses",
    "Visible_Badge",
    "Chair",
    "Lower_Spandex",
    "Headset",
)

## 34 - 6 = 28 classes left.
## Filter classes and colors together so the palette stays aligned with
## the surviving class names.
GOLIATH_CLASSES = tuple(
    name for name in ORIGINAL_GOLIATH_CLASSES if name not in REMOVE_CLASSES
)
GOLIATH_PALETTE = [
    color
    for name, color in zip(ORIGINAL_GOLIATH_CLASSES, ORIGINAL_GOLIATH_PALETTE)
    if name not in REMOVE_CLASSES
]
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ numpy
3
+ torch
4
+ torchvision
5
+ matplotlib
6
+ pillow
7
+ spaces
8
+ opencv-python