Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,10 +4,30 @@ import cv2
|
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
from plotly.subplots import make_subplots
|
| 6 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# FFT processing functions
|
| 9 |
def apply_fft(image):
|
| 10 |
-
"""Apply FFT to each channel of the image and return shifted FFT channels."""
|
| 11 |
fft_channels = []
|
| 12 |
for channel in cv2.split(image):
|
| 13 |
fft = np.fft.fft2(channel)
|
|
@@ -16,7 +36,6 @@ def apply_fft(image):
|
|
| 16 |
return fft_channels
|
| 17 |
|
| 18 |
def filter_fft_percentage(fft_channels, percentage):
|
| 19 |
-
"""Filter FFT channels to keep top percentage of magnitudes."""
|
| 20 |
filtered_fft = []
|
| 21 |
for fft_data in fft_channels:
|
| 22 |
magnitude = np.abs(fft_data)
|
|
@@ -28,7 +47,6 @@ def filter_fft_percentage(fft_channels, percentage):
|
|
| 28 |
return filtered_fft
|
| 29 |
|
| 30 |
def inverse_fft(filtered_fft):
|
| 31 |
-
"""Reconstruct image from filtered FFT channels."""
|
| 32 |
reconstructed_channels = []
|
| 33 |
for fft_data in filtered_fft:
|
| 34 |
fft_ishift = np.fft.ifftshift(fft_data)
|
|
@@ -37,8 +55,22 @@ def inverse_fft(filtered_fft):
|
|
| 37 |
reconstructed_channels.append(img_normalized.astype(np.uint8))
|
| 38 |
return cv2.merge(reconstructed_channels)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def create_3d_plot(fft_channels, downsample_factor=1):
|
| 41 |
-
"""Create interactive 3D surface plots using Plotly."""
|
| 42 |
fig = make_subplots(
|
| 43 |
rows=3, cols=2,
|
| 44 |
specs=[[{'type': 'scene'}, {'type': 'scene'}],
|
|
@@ -51,33 +83,26 @@ def create_3d_plot(fft_channels, downsample_factor=1):
|
|
| 51 |
)
|
| 52 |
)
|
| 53 |
|
| 54 |
-
channel_names = ['Blue', 'Green', 'Red']
|
| 55 |
-
|
| 56 |
for i, fft_data in enumerate(fft_channels):
|
| 57 |
-
# Downsample data for better performance
|
| 58 |
fft_down = fft_data[::downsample_factor, ::downsample_factor]
|
| 59 |
magnitude = np.abs(fft_down)
|
| 60 |
phase = np.angle(fft_down)
|
| 61 |
|
| 62 |
-
# Create grid coordinates
|
| 63 |
rows, cols = magnitude.shape
|
| 64 |
x = np.linspace(-cols//2, cols//2, cols)
|
| 65 |
y = np.linspace(-rows//2, rows//2, rows)
|
| 66 |
X, Y = np.meshgrid(x, y)
|
| 67 |
|
| 68 |
-
# Magnitude plot
|
| 69 |
fig.add_trace(
|
| 70 |
go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
|
| 71 |
row=i+1, col=1
|
| 72 |
)
|
| 73 |
|
| 74 |
-
# Phase plot
|
| 75 |
fig.add_trace(
|
| 76 |
go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
|
| 77 |
row=i+1, col=2
|
| 78 |
)
|
| 79 |
|
| 80 |
-
# Update layout for better visualization
|
| 81 |
fig.update_layout(
|
| 82 |
height=1500,
|
| 83 |
width=1200,
|
|
@@ -93,80 +118,104 @@ def create_3d_plot(fft_channels, downsample_factor=1):
|
|
| 93 |
|
| 94 |
# Streamlit UI
|
| 95 |
st.set_page_config(layout="wide")
|
| 96 |
-
st.title("Interactive Frequency Domain Analysis")
|
| 97 |
-
|
| 98 |
-
#
|
| 99 |
-
|
| 100 |
-
st.
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
|
| 107 |
|
| 108 |
if uploaded_file is not None:
|
| 109 |
-
# Read and display original image
|
| 110 |
file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
|
| 111 |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
| 112 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 113 |
st.image(image_rgb, caption="Original Image", use_column_width=True)
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
if
|
| 117 |
st.session_state.fft_channels = apply_fft(image)
|
| 118 |
|
| 119 |
-
#
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
)
|
| 126 |
-
submit_button = st.form_submit_button(label="Apply Filter")
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
|
| 131 |
-
reconstructed = inverse_fft(filtered_fft)
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
|
| 134 |
|
| 135 |
-
#
|
| 136 |
st.subheader("Frequency Data of Each Channel")
|
| 137 |
-
fft_data_dict = {}
|
| 138 |
for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
|
| 139 |
-
magnitude = np.abs(st.session_state.fft_channels[i])
|
| 140 |
-
phase = np.angle(st.session_state.fft_channels[i])
|
| 141 |
-
fft_data_dict[channel_name] = {'Magnitude': magnitude, 'Phase': phase}
|
| 142 |
-
|
| 143 |
-
# Create DataFrames for each channel's FFT data
|
| 144 |
-
for channel_name, data in fft_data_dict.items():
|
| 145 |
st.write(f"### {channel_name} Channel FFT Data")
|
| 146 |
-
magnitude_df = pd.DataFrame(
|
| 147 |
-
phase_df = pd.DataFrame(
|
| 148 |
st.write("#### Magnitude Data:")
|
| 149 |
-
st.dataframe(magnitude_df.head(10))
|
| 150 |
st.write("#### Phase Data:")
|
| 151 |
-
st.dataframe(phase_df.head(10))
|
| 152 |
-
|
| 153 |
-
# Download button for reconstructed image
|
| 154 |
-
_, encoded_img = cv2.imencode('.png', reconstructed)
|
| 155 |
-
st.download_button(
|
| 156 |
-
"Download Reconstructed Image",
|
| 157 |
-
encoded_img.tobytes(),
|
| 158 |
-
"reconstructed.png",
|
| 159 |
-
"image/png"
|
| 160 |
-
)
|
| 161 |
|
| 162 |
-
# 3D
|
| 163 |
st.subheader("3D Frequency Components Visualization")
|
| 164 |
downsample = st.slider(
|
| 165 |
"Downsampling factor for 3D plots:",
|
| 166 |
-
|
| 167 |
-
help="Controls the resolution of the 3D surface plots.
|
| 168 |
)
|
| 169 |
-
|
| 170 |
-
# Generate and display 3D plots
|
| 171 |
-
fig = create_3d_plot(filtered_fft, downsample)
|
| 172 |
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
from plotly.subplots import make_subplots
|
| 6 |
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
# Dummy CNN Model
|
| 12 |
+
class SimpleCNN(nn.Module):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super(SimpleCNN, self).__init__()
|
| 15 |
+
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
|
| 16 |
+
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
|
| 17 |
+
self.fc1 = nn.Linear(32 * 8 * 8, 128)
|
| 18 |
+
self.fc2 = nn.Linear(128, 10)
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
x1 = F.relu(self.conv1(x)) # First conv layer activation
|
| 22 |
+
x2 = F.relu(self.conv2(x1))
|
| 23 |
+
x3 = F.adaptive_avg_pool2d(x2, (8, 8))
|
| 24 |
+
x4 = x3.view(x3.size(0), -1)
|
| 25 |
+
x5 = F.relu(self.fc1(x4))
|
| 26 |
+
x6 = self.fc2(x5)
|
| 27 |
+
return x6, x1 # Return both output and first layer activations
|
| 28 |
|
| 29 |
# FFT processing functions
|
| 30 |
def apply_fft(image):
|
|
|
|
| 31 |
fft_channels = []
|
| 32 |
for channel in cv2.split(image):
|
| 33 |
fft = np.fft.fft2(channel)
|
|
|
|
| 36 |
return fft_channels
|
| 37 |
|
| 38 |
def filter_fft_percentage(fft_channels, percentage):
|
|
|
|
| 39 |
filtered_fft = []
|
| 40 |
for fft_data in fft_channels:
|
| 41 |
magnitude = np.abs(fft_data)
|
|
|
|
| 47 |
return filtered_fft
|
| 48 |
|
| 49 |
def inverse_fft(filtered_fft):
|
|
|
|
| 50 |
reconstructed_channels = []
|
| 51 |
for fft_data in filtered_fft:
|
| 52 |
fft_ishift = np.fft.ifftshift(fft_data)
|
|
|
|
| 55 |
reconstructed_channels.append(img_normalized.astype(np.uint8))
|
| 56 |
return cv2.merge(reconstructed_channels)
|
| 57 |
|
| 58 |
+
# CNN Pass Visualization
|
| 59 |
+
def pass_to_cnn(fft_image):
|
| 60 |
+
model = SimpleCNN()
|
| 61 |
+
magnitude_tensor = torch.tensor(np.abs(fft_image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
| 62 |
+
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
output, activations = model(magnitude_tensor)
|
| 65 |
+
|
| 66 |
+
# Ensure activations have the correct shape [batch_size, channels, height, width]
|
| 67 |
+
if len(activations.shape) == 3:
|
| 68 |
+
activations = activations.unsqueeze(0) # Add batch dimension if missing
|
| 69 |
+
|
| 70 |
+
return activations, magnitude_tensor
|
| 71 |
+
|
| 72 |
+
# 3D plotting function
|
| 73 |
def create_3d_plot(fft_channels, downsample_factor=1):
|
|
|
|
| 74 |
fig = make_subplots(
|
| 75 |
rows=3, cols=2,
|
| 76 |
specs=[[{'type': 'scene'}, {'type': 'scene'}],
|
|
|
|
| 83 |
)
|
| 84 |
)
|
| 85 |
|
|
|
|
|
|
|
| 86 |
for i, fft_data in enumerate(fft_channels):
|
|
|
|
| 87 |
fft_down = fft_data[::downsample_factor, ::downsample_factor]
|
| 88 |
magnitude = np.abs(fft_down)
|
| 89 |
phase = np.angle(fft_down)
|
| 90 |
|
|
|
|
| 91 |
rows, cols = magnitude.shape
|
| 92 |
x = np.linspace(-cols//2, cols//2, cols)
|
| 93 |
y = np.linspace(-rows//2, rows//2, rows)
|
| 94 |
X, Y = np.meshgrid(x, y)
|
| 95 |
|
|
|
|
| 96 |
fig.add_trace(
|
| 97 |
go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
|
| 98 |
row=i+1, col=1
|
| 99 |
)
|
| 100 |
|
|
|
|
| 101 |
fig.add_trace(
|
| 102 |
go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
|
| 103 |
row=i+1, col=2
|
| 104 |
)
|
| 105 |
|
|
|
|
| 106 |
fig.update_layout(
|
| 107 |
height=1500,
|
| 108 |
width=1200,
|
|
|
|
| 118 |
|
| 119 |
# Streamlit UI
|
| 120 |
st.set_page_config(layout="wide")
|
| 121 |
+
st.title("Interactive Frequency Domain Analysis with CNN")
|
| 122 |
+
|
| 123 |
+
# Initialize session state
|
| 124 |
+
if 'fft_channels' not in st.session_state:
|
| 125 |
+
st.session_state.fft_channels = None
|
| 126 |
+
if 'filtered_fft' not in st.session_state:
|
| 127 |
+
st.session_state.filtered_fft = None
|
| 128 |
+
if 'reconstructed' not in st.session_state:
|
| 129 |
+
st.session_state.reconstructed = None
|
| 130 |
+
if 'show_cnn' not in st.session_state:
|
| 131 |
+
st.session_state.show_cnn = False
|
| 132 |
+
|
| 133 |
+
# Upload image
|
| 134 |
uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
|
| 135 |
|
| 136 |
if uploaded_file is not None:
|
|
|
|
| 137 |
file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
|
| 138 |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
| 139 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 140 |
st.image(image_rgb, caption="Original Image", use_column_width=True)
|
| 141 |
|
| 142 |
+
# Apply FFT and store in session state
|
| 143 |
+
if st.session_state.fft_channels is None:
|
| 144 |
st.session_state.fft_channels = apply_fft(image)
|
| 145 |
|
| 146 |
+
# Frequency percentage slider
|
| 147 |
+
percentage = st.slider(
|
| 148 |
+
"Percentage of frequencies to retain:",
|
| 149 |
+
0.1, 100.0, 10.0, 0.1,
|
| 150 |
+
help="Adjust the slider to select what portion of frequency components to keep."
|
| 151 |
+
)
|
|
|
|
|
|
|
| 152 |
|
| 153 |
+
# Apply FFT filter
|
| 154 |
+
if st.button("Apply Filter"):
|
| 155 |
+
st.session_state.filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
|
| 156 |
+
st.session_state.reconstructed = inverse_fft(st.session_state.filtered_fft)
|
| 157 |
+
st.session_state.show_cnn = False # Reset CNN visualization
|
| 158 |
+
|
| 159 |
+
# Display reconstructed image and FFT data
|
| 160 |
+
if st.session_state.reconstructed is not None:
|
| 161 |
+
reconstructed_rgb = cv2.cvtColor(st.session_state.reconstructed, cv2.COLOR_BGR2RGB)
|
| 162 |
st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
|
| 163 |
|
| 164 |
+
# FFT Data Tables
|
| 165 |
st.subheader("Frequency Data of Each Channel")
|
|
|
|
| 166 |
for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
st.write(f"### {channel_name} Channel FFT Data")
|
| 168 |
+
magnitude_df = pd.DataFrame(np.abs(st.session_state.filtered_fft[i]))
|
| 169 |
+
phase_df = pd.DataFrame(np.angle(st.session_state.filtered_fft[i]))
|
| 170 |
st.write("#### Magnitude Data:")
|
| 171 |
+
st.dataframe(magnitude_df.head(10))
|
| 172 |
st.write("#### Phase Data:")
|
| 173 |
+
st.dataframe(phase_df.head(10))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
+
# 3D Visualization
|
| 176 |
st.subheader("3D Frequency Components Visualization")
|
| 177 |
downsample = st.slider(
|
| 178 |
"Downsampling factor for 3D plots:",
|
| 179 |
+
1, 20, 5,
|
| 180 |
+
help="Controls the resolution of the 3D surface plots."
|
| 181 |
)
|
| 182 |
+
fig = create_3d_plot(st.session_state.filtered_fft, downsample)
|
|
|
|
|
|
|
| 183 |
st.plotly_chart(fig, use_container_width=True)
|
| 184 |
+
|
| 185 |
+
# CNN Visualization Section
|
| 186 |
+
if st.button("Pass to CNN"):
|
| 187 |
+
st.session_state.show_cnn = True
|
| 188 |
+
|
| 189 |
+
if st.session_state.show_cnn:
|
| 190 |
+
st.subheader("CNN Processing Visualization")
|
| 191 |
+
activations, magnitude_tensor = pass_to_cnn(st.session_state.filtered_fft[0])
|
| 192 |
+
|
| 193 |
+
# Display input tensor
|
| 194 |
+
st.write("### Input Magnitude Tensor:")
|
| 195 |
+
st.image(magnitude_tensor.squeeze().numpy(),
|
| 196 |
+
caption="Magnitude Tensor",
|
| 197 |
+
use_column_width=True,
|
| 198 |
+
clamp=True)
|
| 199 |
+
|
| 200 |
+
# Display activations
|
| 201 |
+
st.write("### First Convolution Layer Activations")
|
| 202 |
+
activation = activations.detach().numpy()
|
| 203 |
+
|
| 204 |
+
# Check the shape of the activation tensor
|
| 205 |
+
if len(activation.shape) == 4: # [batch_size, channels, height, width]
|
| 206 |
+
for i in range(activation.shape[1]): # Loop through channels
|
| 207 |
+
act_img = activation[0, i, :, :] # Select the first batch and current channel
|
| 208 |
+
act_img_normalized = (act_img - act_img.min()) / (act_img.max() - act_img.min()) # Normalize
|
| 209 |
+
|
| 210 |
+
# Display activation map
|
| 211 |
+
st.write(f"#### Activation Channel {i+1}")
|
| 212 |
+
st.image(act_img_normalized,
|
| 213 |
+
caption=f"Activation Channel {i+1}",
|
| 214 |
+
use_column_width=True)
|
| 215 |
+
|
| 216 |
+
# Display activation values in a table
|
| 217 |
+
st.write("##### Activation Values:")
|
| 218 |
+
activation_df = pd.DataFrame(act_img)
|
| 219 |
+
st.dataframe(activation_df)
|
| 220 |
+
else:
|
| 221 |
+
st.error(f"Unexpected activation shape: {activation.shape}. Expected 4 dimensions.")
|