File size: 5,656 Bytes
24870a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
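Simplified human parsing driven by 2D pose keypoints.
This replaces the MediaPipe-based parsing with a heuristic that paints coarse
body-part regions (face, hair, torso, arms, legs) from pose keypoints; no
pretrained segmentation model is loaded yet.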
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import gdown

class SimpleHumanParser:
    """
    A simplified human parser that paints coarse body-part labels from pose keypoints.

    No segmentation network is run: every region is derived heuristically from
    2D keypoints. For production use, consider integrating SCHP or Graphonomy.

    Keypoint indices follow the layout the drawing code assumes (presumably
    OpenPose COCO-18 -- confirm against the pose estimator in use):
    0 nose, 2/3/4 right shoulder/elbow/wrist, 5/6/7 left shoulder/elbow/wrist,
    8/9/10 right hip/knee/ankle, 11/12/13 left hip/knee/ankle.
    A keypoint is treated as detected only when all of its coordinates are > 0.
    """

    # Label ids written into the parse map. The values are what this parser
    # has always emitted, so downstream consumers keep working.
    BACKGROUND = 0
    HAIR = 1
    UPPER_CLOTHES = 3
    FACE = 4
    LEFT_ARM = 5
    RIGHT_ARM = 6
    RIGHT_LEG = 9
    LEFT_LEG = 12

    # Torso polygon corners, in drawing order:
    # right shoulder, left shoulder, left hip, right hip.
    _TORSO_KEYPOINTS = (2, 5, 11, 8)

    # Limb segments as (start keypoint, end keypoint, label), in paint order;
    # later segments overwrite earlier ones where they overlap.
    _LIMB_SEGMENTS = (
        (2, 3, RIGHT_ARM), (3, 4, RIGHT_ARM),
        (5, 6, LEFT_ARM), (6, 7, LEFT_ARM),
        (8, 9, RIGHT_LEG), (9, 10, RIGHT_LEG),
        (11, 12, LEFT_LEG), (12, 13, LEFT_LEG),
    )

    def __init__(self):
        # Prefer CUDA, then Apple MPS, then CPU. getattr guards torch builds
        # that predate the torch.backends.mps module.
        mps = getattr(torch.backends, 'mps', None)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif mps is not None and mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        # For now, we use a heuristic-based approach; a production system
        # would load a pretrained model here and use this transform
        # (ImageNet mean/std normalization) on its input.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def parse_image(self, image, pose_data):
        """
        Generate a semantic segmentation map for the person.

        Regions are painted in this order, each overwriting earlier ones where
        they overlap: face, torso (upper clothes), arms, legs, hair.

        Args:
            image: PIL Image (or anything ``np.asarray`` accepts); only its
                height and width are used.
            pose_data: pose keypoints of shape (18, 2) as (x, y) pixel
                coordinates; a numpy array or a nested sequence.

        Returns:
            numpy uint8 array of shape (H, W) holding the label ids defined
            on this class (0 = background everywhere nothing was painted).
        """
        h, w = np.asarray(image).shape[:2]
        pose = np.asarray(pose_data)
        parse_map = np.zeros((h, w), dtype=np.uint8)

        # Head size is approximated as a fixed fraction of the image height.
        head_radius = int(h * 0.08)
        nose_visible = self._is_visible(pose[0])
        if nose_visible:
            nose_x, nose_y = int(pose[0][0]), int(pose[0][1])
            # Face: box from two head-radii above the nose to half a radius below.
            self._fill_box(parse_map,
                           nose_x - head_radius, nose_y - 2 * head_radius,
                           nose_x + head_radius, nose_y + head_radius // 2,
                           self.FACE)

        self._paint_torso(parse_map, pose, w, h)

        for start, end, label in self._LIMB_SEGMENTS:
            if self._is_visible(pose[start]) and self._is_visible(pose[end]):
                self._draw_limb(parse_map, pose[start], pose[end], label, w, h)

        if nose_visible:
            # Hair: band from three to one head-radius above the nose; it
            # overwrites the upper part of the face box painted earlier.
            self._fill_box(parse_map,
                           nose_x - head_radius, nose_y - 3 * head_radius,
                           nose_x + head_radius, nose_y - head_radius,
                           self.HAIR)

        return parse_map

    def _paint_torso(self, parse_map, pose, w, h):
        """Fill the shoulder/hip quadrilateral, enlarged 1.2x about its centroid, with UPPER_CLOTHES."""
        corners = pose[list(self._TORSO_KEYPOINTS)]
        if not all(self._is_visible(p) for p in corners):
            return
        pts = corners.astype(int).astype(np.int32)
        center = pts.mean(axis=0)
        # Expand the polygon slightly so it covers clothing outside the joints.
        pts = ((pts - center) * 1.2 + center).astype(np.int32)
        polygon = [self._as_xy(p) for p in pts]
        covered = self._rasterize(w, h, lambda draw: draw.polygon(polygon, fill=255))
        parse_map[covered] = self.UPPER_CLOTHES

    def _draw_limb(self, parse_map, pt1, pt2, label, w, h):
        """Draw a limb (line from pt1 to pt2, max(10, 3% of h) px thick) on the parse map."""
        thickness = max(10, int(h * 0.03))
        segment = [self._as_xy(pt1), self._as_xy(pt2)]
        covered = self._rasterize(
            w, h, lambda draw: draw.line(segment, fill=255, width=thickness))
        parse_map[covered] = label

    @staticmethod
    def _rasterize(w, h, draw_shape):
        """Return an (h, w) bool mask of the pixels painted by ``draw_shape(ImageDraw.Draw)``.

        Shapes are drawn with a fixed non-zero ink on a blank 8-bit canvas, so
        the mask is independent of the label later written through it
        (comparing the canvas against the label would break for label 0).
        """
        from PIL import ImageDraw
        canvas = Image.new('L', (w, h), 0)
        draw_shape(ImageDraw.Draw(canvas))
        return np.asarray(canvas) > 0

    @staticmethod
    def _fill_box(parse_map, x1, y1, x2, y2, label):
        """Set ``parse_map[y1:y2, x1:x2] = label`` after clamping the box to the map.

        Clamping first prevents negative bounds from being interpreted by
        NumPy as from-the-end indices, which would paint the wrong rows.
        """
        h, w = parse_map.shape
        x1, x2 = (min(max(v, 0), w) for v in (x1, x2))
        y1, y2 = (min(max(v, 0), h) for v in (y1, y2))
        if x2 > x1 and y2 > y1:
            parse_map[y1:y2, x1:x2] = label

    @staticmethod
    def _is_visible(point):
        """Return True when every coordinate of ``point`` is > 0 (non-positive means missing)."""
        return bool(np.all(np.asarray(point) > 0))

    @staticmethod
    def _as_xy(point):
        """Convert a keypoint to a tuple of Python ints (truncating) for PIL drawing calls."""
        return tuple(int(v) for v in np.asarray(point).astype(int))