euiia committed on
Commit
7974f2d
·
verified ·
1 Parent(s): 98772e2

Create tensor_utils.py

Browse files
Files changed (1) hide show
  1. tools/tensor_utils.py +72 -0
tools/tensor_utils.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tools/tensor_utils.py
2
+ #
3
+ # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
+ #
5
+ # Version: 1.0.0
6
+ #
7
+ # This module provides utility functions for tensor manipulation, specifically for
8
+ # image and video processing tasks. The functions here, such as wavelet reconstruction,
9
+ # are internalized within the ADUC-SDR framework to ensure stability and reduce
10
+ # reliance on specific external library structures.
11
+ #
12
+ # The wavelet_reconstruction code is adapted from the SeedVR project.
13
+
14
+ import torch
15
+ from torch import Tensor
16
+ from torch.nn import functional as F
17
+
18
def wavelet_blur(image: Tensor, radius: int) -> Tensor:
    """
    Blur a tensor with a fixed 3x3 smoothing kernel dilated by ``radius``.

    The same kernel is applied to every channel independently (grouped
    convolution). The input is replicate-padded by ``radius`` on each side of
    its last two dimensions, which offsets the shrinkage of the dilated 3x3
    convolution so the spatial size is preserved.

    Args:
        image (Tensor): Input tensor. The channel count is read from
            ``image.shape[1]`` -- presumably an (N, C, H, W) layout; confirm
            against callers.
        radius (int): Padding width and convolution dilation.

    Returns:
        Tensor: The blurred tensor.
    """
    # 3x3 smoothing weights (they sum to 1), created with the input's dtype
    # and device so the convolution needs no implicit conversion.
    base_weights = torch.tensor(
        [
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625],
        ],
        dtype=image.dtype,
        device=image.device,
    )
    # Shape (C, 1, 3, 3): one copy of the kernel per input channel.
    depthwise_weights = base_weights[None, None].repeat(image.shape[1], 1, 1, 1)

    padded = F.pad(image, (radius,) * 4, mode='replicate')

    # groups == channel count -> each channel is filtered on its own.
    return F.conv2d(
        padded,
        depthwise_weights,
        groups=padded.shape[1],
        dilation=radius,
    )
37
+
38
def wavelet_decomposition(image: Tensor, levels: int = 5) -> tuple[Tensor, Tensor]:
    """
    Split a tensor into high-frequency and low-frequency components.

    The input is blurred ``levels`` times with ``wavelet_blur``, doubling the
    radius at every level (1, 2, 4, ...). The detail removed by each blur is
    accumulated into the high-frequency component; what remains after the last
    blur is the low-frequency component. Because the per-level differences
    telescope, ``high_freq + low_freq`` equals ``image`` up to floating-point
    rounding.

    Args:
        image (Tensor): The tensor to decompose.
        levels (int): Number of blur levels. With ``levels=0`` the result is
            an all-zero high-frequency tensor and ``image`` itself.

    Returns:
        tuple[Tensor, Tensor]: ``(high_freq, low_freq)``.
    """
    # NOTE: the return annotation uses the builtin generic ``tuple[...]``
    # because ``typing.Tuple`` is not imported here; annotations are evaluated
    # when the ``def`` runs, so an unbound ``Tuple`` raises NameError on import.
    high_freq = torch.zeros_like(image)
    low_freq = image
    for level in range(levels):
        radius = 2 ** level
        blurred = wavelet_blur(low_freq, radius)
        # Detail lost at this scale is what the blur removed.
        high_freq += low_freq - blurred
        low_freq = blurred

    return high_freq, low_freq
52
+
53
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """
    Combine the high-frequency part of one tensor with the low-frequency part
    of another.

    Both inputs are run through ``wavelet_decomposition`` with its default
    number of levels. The high-frequency component (fine detail) is taken from
    ``content_feat`` and the low-frequency component (intended to carry color
    and lighting) is taken from ``style_feat``; the result is their sum.

    Args:
        content_feat (Tensor): Source of the structural detail.
        style_feat (Tensor): Source of the color/lighting style.

    Returns:
        Tensor: ``content_feat``'s high frequencies plus ``style_feat``'s
        low frequencies.
    """
    # Only the detail layer of the content is kept ...
    detail_layer = wavelet_decomposition(content_feat)[0]
    # ... and only the smooth base layer of the style.
    base_layer = wavelet_decomposition(style_feat)[1]
    return detail_layer + base_layer