File size: 7,276 Bytes
f5fcafb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
Fixed MatAnyone Inference Core
Removes tensor-to-numpy conversion bugs that cause F.pad() errors
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Union, Tuple


def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Reflect-pad *in_tensor* so its last two dims become multiples of *d*.

    The padding on each axis is split as evenly as possible between the two
    sides; any odd remainder goes to the bottom/right side.

    Args:
        in_tensor: Tensor whose last two dims are (H, W).
        d: Divisor the padded H and W must be multiples of.

    Returns:
        Tuple of (padded tensor, pad amounts as (lw, uw, lh, uh)) suitable
        for later removal with ``unpad_tensor``.

    Raises:
        TypeError: If *in_tensor* is not a torch.Tensor.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")

    height, width = in_tensor.shape[-2:]

    # Total padding needed to round each spatial dim up to a multiple of d.
    pad_h = (-height) % d
    pad_w = (-width) % d

    # Smaller half first (top/left); the remainder lands on bottom/right.
    top, bottom = pad_h // 2, pad_h - pad_h // 2
    left, right = pad_w // 2, pad_w - pad_w // 2

    pad_array = (left, right, top, bottom)

    # Keep the input as a tensor end-to-end; F.pad preserves type/device.
    padded = F.pad(in_tensor, pad_array, mode='reflect')

    return padded, pad_array


def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Undo the padding applied by ``pad_divide_by``.

    Args:
        in_tensor: Padded tensor; the last two dims are (H, W).
        pad: Pad amounts as (lw, uw, lh, uh), exactly as returned by
            ``pad_divide_by``.

    Returns:
        A view of *in_tensor* with the padded borders sliced off.

    Raises:
        TypeError: If *in_tensor* is not a torch.Tensor.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
    
    lw, uw, lh, uh = pad
    
    # Each slice is guarded: a zero amount must keep the full extent
    # (e.g. t[..., :-0, :] would produce an empty tensor).
    if lh > 0:
        in_tensor = in_tensor[..., lh:, :]
    if uh > 0:
        in_tensor = in_tensor[..., :-uh, :]
    if lw > 0:
        in_tensor = in_tensor[..., :, lw:]
    if uw > 0:
        in_tensor = in_tensor[..., :, :-uw]
    
    return in_tensor


class InferenceCore:
    """
    FIXED MatAnyone Inference Core
    Handles video matting with proper tensor operations.

    Frames are fed one at a time through :meth:`step`; each call returns a
    2D alpha matte at the input frame's spatial resolution.
    """
    
    def __init__(self, model: torch.nn.Module):
        """
        Args:
            model: Matting network. It is switched to eval mode here, and all
                inputs are moved to the device of its first parameter.
        """
        self.model = model
        self.model.eval()
        self.device = next(model.parameters()).device
        self.pad = None  # (lw, uw, lh, uh) applied by the most recent step()
        
        # Memory storage for temporal consistency
        self.image_feature_store = {}
        self.frame_count = 0
    
    def _ensure_tensor_format(self, 
                             image: Union[torch.Tensor, np.ndarray], 
                             prob: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Normalize inputs to float tensors on ``self.device``.

        Args:
            image: Frame as a (3, H, W) or (1, 3, H, W) tensor, or an HWC /
                CHW ndarray with 3 channels.
            prob: Optional guidance mask; reduced to a 2D (H, W) tensor.

        Returns:
            Tuple of (image as (3, H, W) float tensor, prob as (H, W) float
            tensor or None).

        Raises:
            TypeError: If an input is neither ndarray nor tensor.
            ValueError: If a shape cannot be normalized.
        """
        # Convert image ndarray to a CHW float tensor.
        # NOTE: a 3x3xN input is ambiguous; HWC is assumed when the last
        # dim is 3 (checked first), matching the original behavior.
        if isinstance(image, np.ndarray):
            if image.ndim == 3 and image.shape[-1] == 3:  # HWC format
                image = torch.from_numpy(image.transpose(2, 0, 1)).float()  # Convert to CHW
            elif image.ndim == 3 and image.shape[0] == 3:  # CHW format
                image = torch.from_numpy(image).float()
            else:
                raise ValueError(f"Unexpected image shape: {image.shape}")
        
        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Image must be tensor after conversion, got {type(image)}")
        
        image = image.float().to(self.device)
        
        # Ensure CHW format (3, H, W); a leading batch dim of 1 is dropped.
        if image.ndim == 3 and image.shape[0] == 3:
            pass  # Already correct
        elif image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
            image = image.squeeze(0)  # Remove batch dimension
        else:
            raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")
        
        # Handle probability mask if provided
        if prob is not None:
            if isinstance(prob, np.ndarray):
                prob = torch.from_numpy(prob).float()
            
            if not isinstance(prob, torch.Tensor):
                raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")
            
            prob = prob.float().to(self.device)
            
            # Strip leading singleton dims until the mask is (H, W).
            while prob.ndim > 2:
                prob = prob.squeeze(0)
            
            if prob.ndim != 2:
                raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")
        
        return image, prob
    
    def step(self, 
             image: Union[torch.Tensor, np.ndarray], 
             prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
             **kwargs) -> torch.Tensor:
        """
        Run one matting step on a single frame.

        Args:
            image: Frame, any format accepted by ``_ensure_tensor_format``.
            prob: Optional guidance mask for the foreground.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Alpha matte as an (H, W) tensor clamped to [0, 1].

        Raises:
            ValueError: If the model output has zero channels.
        """
        # Convert inputs to proper tensor format
        image, prob = self._ensure_tensor_format(image, prob)
        
        with torch.no_grad():
            # Reflect-pad so H and W are multiples of 16 (assumed network
            # stride — TODO confirm against the model architecture).
            image_padded, self.pad = pad_divide_by(image, 16)
            
            # Add batch dimension for model
            image_batch = image_padded.unsqueeze(0)  # (1, 3, H_pad, W_pad)
            
            if prob is not None:
                # Resize the mask to the padded image size. Kept 4D end to
                # end: the old squeeze()/unsqueeze() round-trip could
                # collapse a spatial dim of size 1.
                h_pad, w_pad = image_padded.shape[-2:]
                prob_batch = F.interpolate(
                    prob.unsqueeze(0).unsqueeze(0),  # (1, 1, H, W)
                    size=(h_pad, w_pad),
                    mode='bilinear',
                    align_corners=False
                )  # (1, 1, H_pad, W_pad)
                
                # Forward pass with probability guidance
                try:
                    if hasattr(self.model, 'forward_with_prob'):
                        output = self.model.forward_with_prob(image_batch, prob_batch)
                    else:
                        # Fallback: concatenate prob as additional channel
                        input_tensor = torch.cat([image_batch, prob_batch], dim=1)  # (1, 4, H_pad, W_pad)
                        output = self.model(input_tensor)
                except Exception:
                    # Deliberate best-effort: ignore guidance if the model
                    # rejects it and run on the image alone.
                    output = self.model(image_batch)
            else:
                # Forward pass without probability guidance
                output = self.model(image_batch)
            
            # Extract the alpha channel as (1, H_pad, W_pad).
            # BUGFIX: the old multi-channel branch used output[:, -1:, :, :],
            # which stayed 4D, so the final squeeze(0) returned (1, H, W)
            # instead of (H, W). output[:, -1] is 3D for both the single-
            # and multi-channel cases.
            if output.shape[1] < 1:
                raise ValueError(f"Unexpected model output shape: {output.shape}")
            alpha = output[:, -1]  # (1, H_pad, W_pad), last channel is alpha
            
            # Remove padding
            alpha_unpadded = unpad_tensor(alpha, self.pad)
            
            # Remove batch dimension and ensure 2D output
            alpha_final = alpha_unpadded.squeeze(0)  # (H, W)
            
            # Ensure values are in [0, 1] range
            alpha_final = torch.clamp(alpha_final, 0.0, 1.0)
            
            self.frame_count += 1
            
            return alpha_final
    
    def clear_memory(self):
        """Clear stored features and reset the frame counter."""
        self.image_feature_store.clear()
        self.frame_count = 0