File size: 6,681 Bytes
2b534de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from pathlib import Path
import cv2
import numpy as np

def vis_tensors_A(l_tensor_or_named_tensor, path_grid, vis_batch_size=4, layout='auto'):
    """Visualize a list of tensors in a grid layout.

    Args:
        l_tensor_or_named_tensor: [tensor | (name, tensor), ..]. Each tensor is
            B,(C,)H,W in [-1, 1] range with C <= 3; a 3-D (B,H,W) or
            single-channel tensor is treated as a mask and expanded to
            3 channels without unnormalizing. None entries are skipped.
        path_grid: path (str or Path) where the grid image is saved.
        vis_batch_size: max number of samples per tensor to visualize.
        layout: 'BxI' (rows = batch, cols = images), 'IxB' (rows = images,
            cols = batch), or 'auto' to pick based on the B/I ratio.
    """
    import torch
    from torchvision.utils import make_grid, save_image
    path_grid = Path(path_grid)
    # BUG FIX: was parents=0 (falsy), which fails for nested not-yet-existing dirs.
    path_grid.parent.mkdir(parents=True, exist_ok=True)

    def prepare_for_vis(tensor):
        """Return a B,3,H,W CPU tensor in [0,1]: masks expanded, images unnormalized."""
        if tensor is None:
            return None
        # BUG FIX: give 3-D (B,H,W) masks a channel dim BEFORE the channel
        # assert (H may be > 3) and before the repeat below.
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(1)
        assert tensor.shape[1] <= 3
        if tensor.shape[1] == 1:
            return tensor.repeat(1, 3, 1, 1).cpu()  # expand mask to 3 channels
        return (tensor * 0.5 + 0.5).cpu()  # unnormalize from [-1, 1] to [0, 1]

    named_tensors = []
    for tensor_or_named_tensor in l_tensor_or_named_tensor:
        if isinstance(tensor_or_named_tensor, tuple):
            name, tensor = tensor_or_named_tensor
        else:
            name, tensor = "", tensor_or_named_tensor
        if tensor is not None:
            named_tensors.append((name, prepare_for_vis(tensor.detach()[:vis_batch_size])))
    if not named_tensors:
        return

    # Pad every image to the largest spatial size so all fit in one grid.
    all_shapes = [img.shape[2:] for _, img in named_tensors if img is not None]
    if len(set(all_shapes)) > 1:
        max_h = max(shape[0] for shape in all_shapes)
        max_w = max(shape[1] for shape in all_shapes)
        for i, (name, img) in enumerate(named_tensors):
            if img is None:
                continue
            pad_h = max_h - img.shape[2]
            pad_w = max_w - img.shape[3]
            if pad_h or pad_w:
                named_tensors[i] = (name, torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), value=0))

    tensors = []
    for name, tensor in named_tensors:
        tensor = tensor.detach()
        if name:
            for b in range(tensor.shape[0]):
                # Convert to numpy so OpenCV can draw the label onto the image.
                img = tensor[b].permute(1, 2, 0).numpy()
                img = (img * 255).astype(np.uint8).copy()  # contiguous copy for OpenCV
                # Draw the name twice (thick black under thin white) for legibility.
                cv2.putText(img, name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, (0, 0, 0), 2, cv2.LINE_AA)
                cv2.putText(img, name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, (255, 255, 255), 1, cv2.LINE_AA)
                tensors.append(torch.from_numpy(img).permute(2, 0, 1) / 255.0)
        else:
            for b in range(tensor.shape[0]):
                tensors.append(tensor[b])

    if tensors:
        all_images_flat = torch.stack(tensors)  # I*B,3,H,W
        I = len(named_tensors)
        # BUG FIX: use the actual per-tensor batch size — the real batch may be
        # smaller than vis_batch_size, which would break the reshape below.
        B = all_images_flat.shape[0] // I
        if layout == 'auto':
            layout = 'IxB' if B / I > 0.8 else 'BxI'
        if layout == 'BxI':
            # Regroup from image-major (I,B,...) to batch-major (B,I,...).
            grouped = all_images_flat.reshape(I, B, *all_images_flat.shape[1:])
            all_images_flat = grouped.permute(1, 0, 2, 3, 4).reshape(-1, *all_images_flat.shape[1:])
            nrow = I
        else:  # 'IxB'
            nrow = B
        save_image(make_grid(all_images_flat, nrow=nrow), path_grid)
        print(f"{path_grid=}")

def visualize_landmarks(image, landmarks, save_path):
    """
    Draw landmarks on an image and save the result.

    Args:
        image: Input RGB image as a numpy array (H,W,3) with values in [0,255].
        landmarks: Numpy array of shape (136,) or (68,2) containing 68 keypoint coordinates.
        save_path: Path (str or Path) where the annotated image should be written.
    """
    # Clone, force uint8, and make the buffer contiguous so OpenCV can draw on it.
    canvas = np.ascontiguousarray(image.copy().astype(np.uint8))

    # Reshape flat (136,) landmarks into (68,2) if needed.
    if landmarks.shape[0] == 136:
        landmarks = landmarks.reshape(68, 2)

    # Draw each landmark as a small filled green circle.
    for (x, y) in landmarks:
        cv2.circle(canvas, (int(x), int(y)), 2, (0, 255, 0), -1)

    # BUG FIX: create the output directory (consistent with the other vis
    # helpers) and pass a str — cv2.imwrite does not reliably accept Path.
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    # Image is RGB; OpenCV writes BGR, so convert before saving.
    cv2.imwrite(str(save_path), cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))

def visualize_headPose(img_path, yaw, pitch, roll, save_path):
    """Visualize pose angles on image using arrows.

    Args:
        img_path: Path to input image.
        yaw: Yaw angle in degrees (red arrow).
        pitch: Pitch angle in degrees (green arrow).
        roll: Roll angle in degrees (blue arrow).
        save_path: Path to save visualization.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    import matplotlib.pyplot as plt
    img = cv2.imread(str(img_path))
    # BUG FIX: cv2.imread silently returns None on failure; fail loudly here
    # instead of crashing with an opaque cvtColor error.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    center = (w // 2, h // 2)

    plt.figure(figsize=(10, 10))
    plt.imshow(img)

    arrow_len = 100  # arrow length in pixels
    # (color, label, angle, angle -> unit direction). y is flipped for yaw and
    # pitch because image coordinates grow downward.
    specs = [
        ('r', 'Yaw', yaw, lambda a: (np.sin(a), -np.cos(a))),
        ('g', 'Pitch', pitch, lambda a: (np.sin(a), -np.cos(a))),
        ('b', 'Roll', roll, lambda a: (np.cos(a), np.sin(a))),
    ]
    for color, name, angle, direction in specs:
        dx, dy = direction(np.radians(angle))
        plt.arrow(center[0], center[1], int(arrow_len * dx), int(arrow_len * dy),
                  color=color, width=2, head_width=20, label=f'{name}: {angle:.1f}°')

    plt.legend()
    plt.axis('off')

    # Save visualization, creating the output directory if needed.
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    print(f"{save_path=}")