Josh Brown Kramer commited on
Commit
36c1e20
·
1 Parent(s): a56695f

face parsing demo working locally

Browse files
Files changed (1) hide show
  1. faceparsing.py +13 -5
faceparsing.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from torch import nn
3
  from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
 
4
 
5
  from PIL import Image
6
  import matplotlib.pyplot as plt
@@ -38,17 +39,24 @@ def get_face_mask(image):
38
 
39
  # move to CPU to visualize in matplotlib
40
  labels_viz = labels.cpu().numpy()
 
 
 
41
 
42
  #Map to something more colorful. Use a color map to map the labels to a color.
43
  #Create a color map for colors 0 through 18
44
  color_map = plt.get_cmap('tab20')
45
- #Map the labels to a color
46
- colors = color_map(labels_viz)
 
 
47
 
48
- #Convert to PIL Image
49
- colors_pil = Image.fromarray((colors * 255).astype(np.uint8))
50
 
 
 
 
51
 
52
- return labels_viz
 
53
 
54
 
 
1
  import torch
2
  from torch import nn
3
  from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
4
+ import numpy as np
5
 
6
  from PIL import Image
7
  import matplotlib.pyplot as plt
 
39
 
40
  # move to CPU to visualize in matplotlib
41
  labels_viz = labels.cpu().numpy()
42
+
43
+ # Debug: print label statistics
44
+ print(f"Labels min: {labels_viz.min()}, max: {labels_viz.max()}, unique: {np.unique(labels_viz)}")
45
 
46
  #Map to something more colorful. Use a color map to map the labels to a color.
47
  #Create a color map for colors 0 through 18
48
  color_map = plt.get_cmap('tab20')
49
+ #Map the labels to a color - normalize labels to 0-1 range for the colormap
50
+ # Face parsing models typically have 19 classes (0-18), so normalize by 18
51
+ normalized_labels = labels_viz.astype(np.float32) / 18.0
52
+ colors = color_map(normalized_labels)
53
 
 
 
54
 
55
+ #Convert to PIL Image - take only RGB channels (drop alpha)
56
+ colors_rgb = colors[:, :, :3] # Remove alpha channel
57
+ colors_pil = Image.fromarray((colors_rgb * 255).astype(np.uint8))
58
 
59
+
60
+ return colors_pil
61
 
62