MogensR committed on
Commit
bbb6939
·
verified ·
1 Parent(s): f5fcafb

Create matanyone_fixed/utils/tensor_utils.py

Browse files
Files changed (1) hide show
  1. matanyone_fixed/utils/tensor_utils.py +233 -0
matanyone_fixed/utils/tensor_utils.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fixed MatAnyone Tensor Utilities
3
+ Ensures all tensor operations remain in tensor format
4
+ """
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from typing import Tuple, Union
10
+
11
+
12
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Pad a tensor so its spatial dimensions are divisible by d.

    Padding is split as evenly as possible between the two sides of each
    spatial axis (the extra pixel of an odd pad goes to the bottom/right).
    Reflection padding is used when it is valid; for degenerate inputs
    where reflection is unsupported (tensor rank not 3 or 4, or a pad
    amount >= the corresponding input dimension), zero padding is used
    instead of letting F.pad raise a RuntimeError.

    Args:
        in_tensor: Input tensor with spatial dims last (..., H, W)
        d: Divisor value

    Returns:
        padded_tensor: Padded tensor
        pad_info: Padding applied as (left, right, top, bottom)

    Raises:
        TypeError: If in_tensor is not a torch.Tensor.
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)} - this is the source of F.pad() errors!")

    # Spatial dimensions
    h, w = in_tensor.shape[-2:]

    # Round each spatial dim up to the next multiple of d.
    new_h = ((h + d - 1) // d) * d
    new_w = ((w + d - 1) // d) * d

    pad_h = new_h - h
    pad_w = new_w - w

    # Split padding evenly on both sides.
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    # PyTorch 2D padding format: (left, right, top, bottom)
    pad_array = (pad_left, pad_right, pad_top, pad_bottom)

    # 'reflect' with a 4-tuple requires a 3D/4D tensor and pad sizes strictly
    # smaller than the corresponding input dimension; fall back to zero
    # padding for degenerate cases instead of crashing.
    reflect_ok = (
        in_tensor.ndim in (3, 4)
        and pad_left < w and pad_right < w
        and pad_top < h and pad_bottom < h
    )
    out = F.pad(in_tensor, pad_array, mode='reflect' if reflect_ok else 'constant')

    return out, pad_array
50
+
51
+
52
def unpad_tensor(padded_tensor: torch.Tensor, pad_info: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Crop away the padding previously added by pad_divide_by.

    Args:
        padded_tensor: Tensor whose last two dims carry the padding.
        pad_info: Padding amounts as (left, right, top, bottom).

    Returns:
        unpadded_tensor: The tensor restored to its pre-padding spatial size.

    Raises:
        TypeError: If padded_tensor is not a torch.Tensor.
    """
    if not isinstance(padded_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(padded_tensor)}")

    left, right, top, bottom = pad_info
    height, width = padded_tensor.shape[-2:]

    # A zero pad on the far side must map to the full extent, not index 0.
    row_end = height - bottom if bottom > 0 else height
    col_end = width - right if right > 0 else width

    return padded_tensor[..., top:row_end, left:col_end]
81
+
82
+
83
def ensure_tensor(input_data: Union[torch.Tensor, np.ndarray], device: torch.device = None) -> torch.Tensor:
    """
    Coerce input to a float32 tensor, optionally moving it to a device.

    Args:
        input_data: A torch.Tensor or numpy array.
        device: Optional target device; skipped when None.

    Returns:
        torch.Tensor: Float tensor on the requested device.

    Raises:
        TypeError: For any input that is neither a tensor nor an ndarray.
    """
    if isinstance(input_data, torch.Tensor):
        result = input_data.float()
    elif isinstance(input_data, np.ndarray):
        result = torch.from_numpy(input_data).float()
    else:
        raise TypeError(f"Unsupported input type: {type(input_data)}")

    return result if device is None else result.to(device)
105
+
106
+
107
def normalize_tensor(tensor: torch.Tensor, target_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Rescale a tensor's values linearly into a target range.

    Args:
        tensor: Input tensor.
        target_range: Desired (min, max) output range.

    Returns:
        torch.Tensor: Tensor with values mapped into target_range.

    Raises:
        TypeError: If the input is not a torch.Tensor.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    lo, hi = target_range
    t_min = tensor.min()
    t_max = tensor.max()

    # Map to [0, 1]; a constant tensor maps to all zeros (avoids div-by-zero).
    if t_max > t_min:
        unit = (tensor - t_min) / (t_max - t_min)
    else:
        unit = tensor - t_min

    # Stretch/shift into the requested range.
    return unit * (hi - lo) + lo
136
+
137
+
138
def resize_tensor(tensor: torch.Tensor,
                  size: Tuple[int, int],
                  mode: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """
    Resize a tensor's spatial dims, returning the same rank as the input.

    Args:
        tensor: Input tensor, (C, H, W) or (B, C, H, W).
        size: Target (height, width).
        mode: Interpolation mode passed to F.interpolate.
        align_corners: Align-corners flag passed to F.interpolate.

    Returns:
        torch.Tensor: Resized tensor with the input's original rank.

    Raises:
        TypeError: If the input is not a torch.Tensor.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    # F.interpolate wants a batch dim; remember whether we synthesized one.
    squeeze_back = tensor.ndim == 3
    batched = tensor.unsqueeze(0) if squeeze_back else tensor

    out = F.interpolate(batched, size=size, mode=mode, align_corners=align_corners)

    return out.squeeze(0) if squeeze_back else out
171
+
172
+
173
def safe_tensor_operation(func):
    """
    Decorator that rejects array-like (but non-tensor) arguments before they
    reach a tensor operation.

    Any argument exposing a ``.shape`` attribute that is not a torch.Tensor
    (e.g. a numpy array) triggers a TypeError at the call boundary, producing
    a clear error instead of a confusing failure deep inside torch.

    Args:
        func: Function whose array-like arguments must all be tensors.

    Returns:
        Wrapped function with the original's metadata preserved.

    Raises:
        TypeError: Raised by the wrapper when an array-like argument is not
            a torch.Tensor.
    """
    import functools

    # functools.wraps preserves __name__/__doc__ for introspection and
    # debugging; without it every wrapped function reports as 'wrapper'.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Validate positional arguments.
        for i, arg in enumerate(args):
            if hasattr(arg, 'shape') and not isinstance(arg, torch.Tensor):
                raise TypeError(f"Argument {i} must be torch.Tensor, got {type(arg)}")

        # Keyword arguments were previously unchecked, letting numpy arrays
        # slip through when passed by name; validate them too.
        for name, value in kwargs.items():
            if hasattr(value, 'shape') and not isinstance(value, torch.Tensor):
                raise TypeError(f"Argument {name} must be torch.Tensor, got {type(value)}")

        return func(*args, **kwargs)

    return wrapper
186
+
187
+
188
@safe_tensor_operation
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """
    Safely convert a tensor to a numpy array, handling grad and GPU tensors.

    Args:
        tensor: Input tensor (CPU or CUDA, with or without autograd history).

    Returns:
        np.ndarray: The tensor's data as a numpy array.
    """
    result = tensor
    # .numpy() raises on tensors that require grad; detach first.
    if result.requires_grad:
        result = result.detach()
    # .numpy() only works on CPU memory.
    if result.is_cuda:
        result = result.cpu()
    return result.numpy()
206
+
207
+
208
def validate_tensor_shapes(*tensors: torch.Tensor, expected_dims: int = None) -> bool:
    """
    Check that a group of tensors agree in rank and spatial extent.

    Args:
        tensors: Tensors to compare; an empty call is trivially valid.
        expected_dims: If given, every tensor must have exactly this rank.

    Returns:
        bool: True when all checks pass.

    Raises:
        ValueError: On a rank mismatch or differing spatial (H, W) dims.
    """
    if not tensors:
        return True

    # Rank check first, across every tensor.
    if expected_dims is not None:
        for t in tensors:
            if t.ndim != expected_dims:
                raise ValueError(f"Expected {expected_dims}D tensor, got {t.ndim}D")

    # All tensors must share the spatial footprint (last two dims) of the first.
    spatial = tensors[0].shape[-2:]
    for t in tensors[1:]:
        if t.shape[-2:] != spatial:
            raise ValueError(f"Spatial dimensions mismatch: {spatial} vs {t.shape[-2:]}")

    return True