carpedm20 committed on
Commit
3dc508a
·
verified ·
1 Parent(s): f6e4f38

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +58 -124
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import cv2
3
  import gradio as gr
4
  import numpy as np
5
  import torch
@@ -9,7 +8,7 @@ from PIL import Image
9
  import tempfile
10
 
11
  from gradio.themes.utils import sizes
12
- from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
13
 
14
 
15
  # =========================================================
@@ -34,32 +33,22 @@ class ModelManager:
34
  _cache = {}
35
 
36
  @staticmethod
37
- def load_model(checkpoint_name: str):
38
- if checkpoint_name in ModelManager._cache:
39
- return ModelManager._cache[checkpoint_name]
40
-
41
- checkpoint_path = os.path.join(
42
- Config.CHECKPOINTS_DIR,
43
- Config.CHECKPOINTS[checkpoint_name],
44
- )
45
- model = torch.jit.load(checkpoint_path)
46
- model.eval()
47
- model.to("cuda")
48
- ModelManager._cache[checkpoint_name] = model
49
  return model
50
 
51
  @staticmethod
52
  @torch.inference_mode()
53
- def run_model(model, input_tensor, height, width):
54
- output = model(input_tensor)
55
- output = F.interpolate(
56
- output,
57
- size=(height, width),
58
- mode="bilinear",
59
- align_corners=False,
60
- )
61
- _, preds = torch.max(output, 1)
62
- return preds
63
 
64
 
65
  # =========================================================
@@ -68,7 +57,7 @@ class ModelManager:
68
 
69
  class ImageProcessor:
70
  def __init__(self):
71
- self.transform_fn = transforms.Compose([
72
  transforms.Resize((1024, 768)),
73
  transforms.ToTensor(),
74
  transforms.Normalize(
@@ -77,40 +66,37 @@ class ImageProcessor:
77
  ),
78
  ])
79
 
80
- def process_image(self, image: Image.Image, model_name: str):
81
  model = ModelManager.load_model(model_name)
82
- input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")
83
-
84
- preds = ModelManager.run_model(
85
- model,
86
- input_tensor,
87
- image.height,
88
- image.width,
89
- )
90
 
91
- mask = preds.squeeze(0).cpu().numpy()
92
- blended_image = self.visualize_pred_with_overlay(image, mask)
93
 
 
94
  npy_path = tempfile.mktemp(suffix=".npy")
95
  np.save(npy_path, mask)
96
 
97
- return blended_image, npy_path
 
98
 
99
- @staticmethod
100
- def visualize_pred_with_overlay(img, sem_seg, alpha=0.5):
101
- img_np = np.array(img.convert("RGB"))
102
- sem_seg = np.array(sem_seg)
 
 
 
103
 
104
- num_classes = len(GOLIATH_CLASSES)
105
- ids = np.unique(sem_seg)
106
- ids = ids[ids < num_classes]
107
 
108
- overlay = np.zeros((*sem_seg.shape, 3), dtype=np.uint8)
109
- for label in ids:
110
- overlay[sem_seg == label] = GOLIATH_PALETTE[label]
111
 
112
- blended = np.uint8(img_np * (1 - alpha) + overlay * alpha)
113
- return Image.fromarray(blended)
114
 
115
 
116
  # =========================================================
@@ -119,12 +105,9 @@ class ImageProcessor:
119
 
120
  class GradioInterface:
121
  def __init__(self):
122
- self.image_processor = ImageProcessor()
123
 
124
- def create_interface(self):
125
- # -------------------------
126
- # Theme (modern Gradio)
127
- # -------------------------
128
  theme = gr.themes.Soft(
129
  primary_hue="neutral",
130
  secondary_hue="slate",
@@ -135,28 +118,14 @@ class GradioInterface:
135
  body_background_fill="#1a1a1a",
136
  body_text_color="#fafafa",
137
  block_background_fill="#2a2a2a",
138
- block_border_color="#333333",
139
- button_primary_background_fill="#4a4a4a",
140
- button_primary_background_fill_hover="#5a5a5a",
141
- input_background_fill="#3a3a3a",
142
  )
143
 
144
- # -------------------------
145
- # Minimal CSS (layout only)
146
- # -------------------------
147
  css = """
148
- .image-preview img {
149
- max-width: 512px;
150
- max-height: 512px;
151
- margin: 0 auto;
152
- display: block;
153
- object-fit: contain;
154
- border-radius: 6px;
155
- }
156
  .app-header {
157
  padding: 24px;
158
- margin-bottom: 24px;
159
  text-align: center;
 
160
  }
161
  .app-title {
162
  font-size: 48px;
@@ -166,44 +135,27 @@ class GradioInterface:
166
  font-size: 24px;
167
  opacity: 0.9;
168
  }
169
- .publication-links {
170
- display: flex;
171
- justify-content: center;
172
- flex-wrap: wrap;
173
- gap: 8px;
174
- margin-top: 12px;
175
- }
176
  """
177
 
178
- header_html = """
179
  <div class="app-header">
180
- <h1 class="app-title">Sapiens: Body-Part Segmentation</h1>
181
  <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
182
- <p>
183
- Foundation models for human-centric vision tasks pretrained on
184
- 300M human images. This demo showcases fine-tuned body-part
185
- segmentation.
186
- </p>
187
- <div class="publication-links">
188
- <a href="https://arxiv.org/abs/2408.12569">arXiv</a>
189
- <a href="https://github.com/facebookresearch/sapiens">GitHub</a>
190
- <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/">Meta</a>
191
- </div>
192
  </div>
193
  """
194
 
195
- def process(image, model_name):
196
- return self.image_processor.process_image(image, model_name)
197
 
198
  with gr.Blocks(theme=theme, css=css) as demo:
199
- gr.HTML(header_html)
200
 
201
  with gr.Row():
202
- with gr.Column():
203
  input_image = gr.Image(
204
  label="Input Image",
205
  type="pil",
206
- elem_classes="image-preview",
207
  )
208
 
209
  model_name = gr.Dropdown(
@@ -212,37 +164,21 @@ class GradioInterface:
212
  value="1b",
213
  )
214
 
215
- gr.Examples(
216
- inputs=input_image,
217
- examples=[
218
- os.path.join(Config.ASSETS_DIR, "images", img)
219
- for img in os.listdir(
220
- os.path.join(Config.ASSETS_DIR, "images")
221
- )
222
- ],
223
- examples_per_page=14,
224
- )
225
 
226
- with gr.Column():
227
- result_image = gr.Image(
228
  label="Segmentation Result",
229
- type="pil",
230
- elem_classes="image-preview",
231
- )
232
- npy_output = gr.File(label="Segmentation (.npy)")
233
- run_button = gr.Button("Run", variant="primary")
234
-
235
- gr.Image(
236
- os.path.join(Config.ASSETS_DIR, "palette.jpg"),
237
- label="Class Palette",
238
- type="filepath",
239
- elem_classes="image-preview",
240
  )
241
 
242
- run_button.click(
243
- fn=process,
 
 
244
  inputs=[input_image, model_name],
245
- outputs=[result_image, npy_output],
246
  )
247
 
248
  return demo
@@ -257,11 +193,9 @@ def main():
257
  torch.backends.cuda.matmul.allow_tf32 = True
258
  torch.backends.cudnn.allow_tf32 = True
259
 
260
- interface = GradioInterface()
261
- demo = interface.create_interface()
262
- demo.launch(server_name="0.0.0.0", share=False)
263
 
264
 
265
  if __name__ == "__main__":
266
  main()
267
-
 
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
8
  import tempfile
9
 
10
  from gradio.themes.utils import sizes
11
+ from classes_and_palettes import GOLIATH_CLASSES
12
 
13
 
14
  # =========================================================
 
33
  _cache = {}
34
 
35
  @staticmethod
36
+ def load_model(name: str):
37
+ if name in ModelManager._cache:
38
+ return ModelManager._cache[name]
39
+
40
+ path = os.path.join(Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[name])
41
+ model = torch.jit.load(path)
42
+ model.eval().to("cuda")
43
+ ModelManager._cache[name] = model
 
 
 
 
44
  return model
45
 
46
  @staticmethod
47
  @torch.inference_mode()
48
+ def run(model, x, h, w):
49
+ out = model(x)
50
+ out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=False)
51
+ return out.argmax(1)
 
 
 
 
 
 
52
 
53
 
54
  # =========================================================
 
57
 
58
  class ImageProcessor:
59
  def __init__(self):
60
+ self.tf = transforms.Compose([
61
  transforms.Resize((1024, 768)),
62
  transforms.ToTensor(),
63
  transforms.Normalize(
 
66
  ),
67
  ])
68
 
69
+ def process(self, image: Image.Image, model_name: str):
70
  model = ModelManager.load_model(model_name)
71
+ x = self.tf(image).unsqueeze(0).to("cuda")
 
 
 
 
 
 
 
72
 
73
+ pred = ModelManager.run(model, x, image.height, image.width)
74
+ mask = pred.squeeze(0).cpu().numpy()
75
 
76
+ # Save raw mask
77
  npy_path = tempfile.mktemp(suffix=".npy")
78
  np.save(npy_path, mask)
79
 
80
+ # Build AnnotatedImage output
81
+ annotations = self._build_annotations(mask)
82
 
83
+ return (image, annotations), npy_path
84
+
85
+ def _build_annotations(self, mask: np.ndarray):
86
+ annotations = []
87
+ for class_id in np.unique(mask):
88
+ if class_id >= len(GOLIATH_CLASSES):
89
+ continue
90
 
91
+ binary_mask = (mask == class_id).astype(np.uint8)
92
+ if binary_mask.sum() == 0:
93
+ continue
94
 
95
+ annotations.append(
96
+ (binary_mask, GOLIATH_CLASSES[class_id])
97
+ )
98
 
99
+ return annotations
 
100
 
101
 
102
  # =========================================================
 
105
 
106
  class GradioInterface:
107
  def __init__(self):
108
+ self.processor = ImageProcessor()
109
 
110
+ def create(self):
 
 
 
111
  theme = gr.themes.Soft(
112
  primary_hue="neutral",
113
  secondary_hue="slate",
 
118
  body_background_fill="#1a1a1a",
119
  body_text_color="#fafafa",
120
  block_background_fill="#2a2a2a",
121
+ block_border_color="#333",
 
 
 
122
  )
123
 
 
 
 
124
  css = """
 
 
 
 
 
 
 
 
125
  .app-header {
126
  padding: 24px;
 
127
  text-align: center;
128
+ margin-bottom: 24px;
129
  }
130
  .app-title {
131
  font-size: 48px;
 
135
  font-size: 24px;
136
  opacity: 0.9;
137
  }
 
 
 
 
 
 
 
138
  """
139
 
140
+ header = """
141
  <div class="app-header">
142
+ <h1 class="app-title">Sapiens Body-Part Segmentation</h1>
143
  <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
144
+ <p>Foundation model fine-tuned for dense human part segmentation.</p>
 
 
 
 
 
 
 
 
 
145
  </div>
146
  """
147
 
148
+ def run(image, model):
149
+ return self.processor.process(image, model)
150
 
151
  with gr.Blocks(theme=theme, css=css) as demo:
152
+ gr.HTML(header)
153
 
154
  with gr.Row():
155
+ with gr.Column(scale=1):
156
  input_image = gr.Image(
157
  label="Input Image",
158
  type="pil",
 
159
  )
160
 
161
  model_name = gr.Dropdown(
 
164
  value="1b",
165
  )
166
 
167
+ run_btn = gr.Button("Run Segmentation", variant="primary")
 
 
 
 
 
 
 
 
 
168
 
169
+ with gr.Column(scale=2):
170
+ annotated = gr.AnnotatedImage(
171
  label="Segmentation Result",
172
+ show_legend=True,
173
+ height=512,
 
 
 
 
 
 
 
 
 
174
  )
175
 
176
+ mask_file = gr.File(label="Raw Mask (.npy)")
177
+
178
+ run_btn.click(
179
+ fn=run,
180
  inputs=[input_image, model_name],
181
+ outputs=[annotated, mask_file],
182
  )
183
 
184
  return demo
 
193
  torch.backends.cuda.matmul.allow_tf32 = True
194
  torch.backends.cudnn.allow_tf32 = True
195
 
196
+ app = GradioInterface().create()
197
+ app.launch(server_name="0.0.0.0", share=False)
 
198
 
199
 
200
  if __name__ == "__main__":
201
  main()