File size: 8,547 Bytes
f783161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1a407f
f783161
a1a407f
f783161
a1a407f
f783161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import torch
from PIL import Image
from app_utils import *
import torch.nn.functional as F
import numpy as np
from torchvision import transforms as TF

from scipy.special import i0
from scipy.optimize import curve_fit
from scipy.integrate import trapezoid
from functools import partial

def von_mises_pdf_alpha_numpy(alpha, x, mu, kappa):
    """Evaluate an ``alpha``-fold periodic, von Mises-shaped curve.

    Computes ``exp(kappa * cos(alpha * (x - mu))) / (2 * pi)``.

    NOTE(review): the denominator is the constant ``2 * pi`` only; the
    ``I0(kappa)`` Bessel factor of the textbook von Mises density is not
    applied here. Presumably intentional (the fit thresholds elsewhere depend
    on this exact form) -- confirm before changing.

    Args:
        alpha: Frequency multiplier applied to the angular offset.
        x: Angle(s) in radians; scalar or NumPy array.
        mu: Location of the peak, in radians.
        kappa: Concentration of the curve around ``mu``.

    Returns:
        The curve value(s), with the same shape as ``x``.
    """
    phase = alpha * (x - mu)
    return np.exp(kappa * np.cos(phase)) / (2 * np.pi)

def val_fit_alpha(distribute):
    """Estimate the periodicity (``alpha``) of each angular distribution.

    Every row of ``distribute`` is normalised to unit area over ``[0, 2*pi]``
    and fitted with ``von_mises_pdf_alpha_numpy`` for the candidate
    periodicities 1, 2 and 4, tried in that order and stopping early once a fit
    reaches ``R^2 > 0.8``. The candidate with the highest ``R^2`` is accepted
    only when its fitted ``kappa`` reaches the per-alpha minimum (0.6 / 0.4 /
    0.25) and its ``R^2`` is at least 0.45; otherwise the row is labelled 0.

    Args:
        distribute: Iterable of 1-D sequences, each with 360 samples (the
            fitting grid has 360 points). Rows are copied, so the caller's
            data is never modified.

    Returns:
        A 1-D ``torch.Tensor`` with one value per row, each one of
        ``0., 1., 2., 4.``.
    """
    candidate_alphas = (1.0, 2.0, 4.0)
    # Minimum fitted kappa required before a periodicity is accepted.
    min_kappa = {1.0: 0.6, 2.0: 0.4, 4.0: 0.25}
    min_r_squared = 0.45
    early_stop_r_squared = 0.8

    # The angular grid is identical for every row, so build it once.
    x = np.linspace(0, 2 * np.pi, 360)

    fit_alphas = []
    for row in distribute:
        # Work on a private float copy: normalising in place would silently
        # rescale the caller's array (and fail outright for integer dtypes).
        y_noise = np.array(row, dtype=np.float64)
        y_noise /= trapezoid(y_noise, x) + 1e-8

        initial_guess = [x[np.argmax(y_noise)], 1]
        # Total sum of squares does not depend on the candidate alpha.
        ss_tot = np.sum((y_noise - np.mean(y_noise)) ** 2)

        saved_params = []
        saved_r_squared = []
        for alpha in candidate_alphas:
            try:
                curve = partial(von_mises_pdf_alpha_numpy, alpha)
                params, _ = curve_fit(curve, x, y_noise, p0=initial_guess)
                residuals = y_noise - curve(x, *params)
            except Exception:
                # The fit is best-effort: a failed candidate scores 0 instead
                # of aborting the batch. Unlike a bare ``except``, this lets
                # KeyboardInterrupt / SystemExit propagate.
                saved_params.append((0.0, 0.0))
                saved_r_squared.append(0.0)
                continue

            r_squared = 1 - (np.sum(residuals**2) / (ss_tot + 1e-8))
            saved_params.append(params)
            saved_r_squared.append(r_squared)
            if r_squared > early_stop_r_squared:
                break

        best = int(np.argmax(saved_r_squared))
        alpha = candidate_alphas[best]
        _, kappa_fit = saved_params[best]
        r_squared = saved_r_squared[best]

        accepted = kappa_fit >= min_kappa[alpha] and r_squared >= min_r_squared
        fit_alphas.append(alpha if accepted else 0.0)
    return torch.tensor(fit_alphas)

def _pad_center_white(img, target_height, target_width):
    """Centre-pad a (C, H, W) tensor with white (1.0) up to the target size."""
    h_padding = target_height - img.shape[1]
    w_padding = target_width - img.shape[2]
    if h_padding <= 0 and w_padding <= 0:
        return img

    pad_top = h_padding // 2
    pad_left = w_padding // 2
    return torch.nn.functional.pad(
        img,
        (pad_left, w_padding - pad_left, pad_top, h_padding - pad_top),
        mode="constant",
        value=1.0,
    )


def preprocess_images(image_list, mode="crop"):
    """Convert PIL images into one stacked, model-ready tensor.

    Each image is flattened onto a white background if it is RGBA, converted
    to RGB, resized so that its sides are multiples of 14, and turned into a
    float tensor via ``ToTensor``.

    * ``mode="pad"``: the longer side becomes 518 px, the other side keeps the
      aspect ratio (rounded to a multiple of 14), and the result is padded
      with white to 518 x 518.
    * ``mode="crop"``: the width becomes 518 px, the height keeps the aspect
      ratio (rounded to a multiple of 14) and is centre-cropped to 518 px when
      it is larger.

    If the processed images still differ in shape, a warning is printed and
    all of them are white-padded to the largest height and width.

    Args:
        image_list: Non-empty sequence of PIL images.
        mode: Either ``"crop"`` or ``"pad"``.

    Returns:
        Tensor of shape ``(N, 3, H, W)``.

    Raises:
        ValueError: If ``image_list`` is empty or ``mode`` is unknown.
        RuntimeError: If an image cannot be resized or converted to a tensor.
    """
    # Check for empty list
    if len(image_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518
    patch_size = 14

    for img in image_list:
        # If there's an alpha channel, blend onto a white background so that
        # transparent areas become white rather than black.
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)
        img = img.convert("RGB")
        width, height = img.size

        if mode == "pad" and height > width:
            # Portrait image in pad mode: the height is the longer side.
            new_height = target_size
            new_width = round(width * (new_height / height) / patch_size) * patch_size
        else:
            # Crop mode, or landscape/square image in pad mode: fix the width.
            new_width = target_size
            new_height = round(height * (new_width / width) / patch_size) * patch_size

        # Extreme aspect ratios can round the short side down to 0, which
        # makes the resize fail; keep at least one 14-px patch.
        new_width = max(patch_size, new_width)
        new_height = max(patch_size, new_height)

        # Resize with new dimensions (width, height)
        try:
            img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
            img = to_tensor(img)  # Convert to tensor (0, 1)
        except Exception as e:
            # Raise instead of ``assert False``: asserts vanish under ``-O``
            # and would hide the underlying cause.
            raise RuntimeError(
                f"Failed to resize image from {width}x{height} "
                f"to {new_width}x{new_height}"
            ) from e

        if mode == "crop" and new_height > target_size:
            # Centre-crop the height down to the target size.
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]
        elif mode == "pad":
            # Pad with white to a target_size x target_size square.
            img = _pad_center_white(img, target_size, target_size)

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        max_height = max(h for h, _ in shapes)
        max_width = max(w for _, w in shapes)
        images = [_pad_center_white(img, max_height, max_width) for img in images]

    # torch.stack always adds the leading batch dimension, so a single image
    # already comes out as (1, C, H, W).
    return torch.stack(images)

@torch.no_grad()
def inf_single_batch(model, batch):
    """Run the model on one batch and decode its angle predictions.

    The model output is flattened to ``(B * S, D)`` and sliced as follows:

    * ``[0:360]``   -> ``az`` bin scores, decoded by argmax,
    * ``[360:540]`` -> ``el`` bin scores, argmax shifted by -90,
    * ``[540:900]`` -> ``ro`` bin scores, argmax shifted by -180.

    (The names suggest azimuth / elevation / rotation in degrees -- TODO
    confirm against the model definition.) The sigmoid of the ``az`` scores is
    additionally passed to ``val_fit_alpha`` to obtain ``alpha``.

    Args:
        model: Callable model exposing ``get_device()``; called with the batch.
        batch: Tensor of shape ``(B, S, C, H, W)``.

    Returns:
        Dict with keys ``ref_az_pred``, ``ref_el_pred``, ``ref_ro_pred``,
        ``ref_alpha_pred``, ``rel_az_pred``, ``rel_el_pred``, ``rel_ro_pred``.
        For ``S > 1`` the ``ref_*`` entries are taken from sequence index 0
        and the ``rel_*`` entries from sequence index 1, each with shape
        ``(B,)``. For ``S == 1`` the ``ref_*`` entries are the values of the
        first sample only and the ``rel_*`` entries are ``0.``.
    """
    device = model.get_device()
    # Make sure the inputs live on the model's device; this is a no-op when
    # the caller has already moved them.
    batch_img_inputs = batch.to(device)  # (B, S, 3, H, W)
    B, S, _, _, _ = batch_img_inputs.shape
    pose_enc = model(batch_img_inputs)  # (B, S, D) S = 1

    # reshape (not view) so a non-contiguous model output is handled as well.
    pose_enc = pose_enc.reshape(B * S, -1)
    az_scores = pose_enc[:, 0:360]
    el_scores = pose_enc[:, 360:360 + 180]
    ro_scores = pose_enc[:, 360 + 180:360 + 180 + 360]

    angle_az_pred = torch.argmax(az_scores, dim=-1)
    angle_el_pred = torch.argmax(el_scores, dim=-1) - 90
    angle_ro_pred = torch.argmax(ro_scores, dim=-1) - 180

    # ori_val
    # trained with BCE loss, hence sigmoid rather than softmax
    distribute = torch.sigmoid(az_scores).cpu().float().numpy()
    alpha_pred = val_fit_alpha(distribute=distribute)

    # ref_val
    if S > 1:
        az_by_frame = angle_az_pred.reshape(B, S)
        el_by_frame = angle_el_pred.reshape(B, S)
        ro_by_frame = angle_ro_pred.reshape(B, S)

        ref_az_pred = az_by_frame[:, 0]
        ref_el_pred = el_by_frame[:, 0]
        ref_ro_pred = ro_by_frame[:, 0]
        ref_alpha_pred = alpha_pred.reshape(B, S)[:, 0]
        rel_az_pred = az_by_frame[:, 1]
        rel_el_pred = el_by_frame[:, 1]
        rel_ro_pred = ro_by_frame[:, 1]
    else:
        # Single-frame input: report the first sample, no relative pose.
        ref_az_pred = angle_az_pred[0]
        ref_el_pred = angle_el_pred[0]
        ref_ro_pred = angle_ro_pred[0]
        ref_alpha_pred = alpha_pred[0]
        rel_az_pred = 0.
        rel_el_pred = 0.
        rel_ro_pred = 0.

    ans_dict = {
        'ref_az_pred': ref_az_pred,
        'ref_el_pred': ref_el_pred,
        'ref_ro_pred': ref_ro_pred,
        'ref_alpha_pred' : ref_alpha_pred,
        'rel_az_pred'  : rel_az_pred,
        'rel_el_pred'  : rel_el_pred,
        'rel_ro_pred'  : rel_ro_pred,
    }

    return ans_dict

# input PIL Image
@torch.no_grad()
def inf_single_case(model, image_ref, image_tgt):
    """Run inference for one reference image and an optional target image.

    The images are preprocessed in ``"pad"`` mode, moved to the model's
    device, wrapped in a leading batch dimension and passed to
    ``inf_single_batch``. The resulting dict is printed and returned.

    Args:
        model: Model exposing ``get_device()``, forwarded to
            ``inf_single_batch``.
        image_ref: Reference PIL image.
        image_tgt: Optional second PIL image; ``None`` for single-image
            inference.

    Returns:
        The prediction dict produced by ``inf_single_batch``.
    """
    frames = [image_ref] if image_tgt is None else [image_ref, image_tgt]

    stacked = preprocess_images(frames, mode="pad")
    on_device = stacked.to(model.get_device())
    # Add the batch dimension expected by inf_single_batch.
    result = inf_single_batch(model=model, batch=on_device.unsqueeze(0))

    print(result)
    return result