Anna / app.py
JohnChiu's picture
add elegant algorithm
6fa13e0
import gradio as gr
import numpy as np
import plotly.graph_objs as go
from scipy.ndimage import convolve
from gradio_imageslider import ImageSlider
import cv2
import os
# def readRAW(path):
# arr = np.fromfile(path, dtype=np.int16).reshape(96,240,256)
# # 将最后一维重塑为 (-1, 2),其中 -1 自动计算为 128
# reshaped = arr.reshape(*arr.shape[:-1], -1, 2)
# # 交换每一对中的两个元素
# swapped = reshaped[..., :, ::-1]
# # 恢复原始形状
# histogram_data = swapped.reshape(arr.shape)
# # 定义映射顺序:对每组8行进行调换
# mapping = [0, 4, 1, 5, 2, 6, 3, 7]
# # 每组包含的行数
# group_size = 8
# num_groups = 12 # 96/8
# # 创建一个用于存储结果的数组(也可以原地修改)
# output = np.empty_like(histogram_data)
# # 对每个 group 分别进行行重排
# for g in range(num_groups):
# start = g * group_size
# end = start + group_size
# output[start:end,:,:] = histogram_data[start:end,:,:][mapping,:,:]
# return output
def mipi_raw10_to_raw8_scaled(raw10_data):
    """Unpack MIPI RAW10 bytes into little-endian 16-bit pixel words.

    Every 5-byte group carries four 10-bit pixels: bytes 0-3 hold the upper
    8 bits of pixels 0-3, and byte 4 holds their 2 low bits (pixel k in bits
    2k..2k+1). Trailing bytes that do not fill a whole group are ignored.

    Despite the name, the result is 2 bytes per pixel (low byte first), not
    8-bit data.

    Args:
        raw10_data: Packed RAW10 buffer (anything ``np.frombuffer`` accepts).

    Returns:
        ``bytes`` of length ``8 * (len(raw10_data) // 5)``.
    """
    packed = np.frombuffer(raw10_data, dtype=np.uint8)
    usable = packed.size - packed.size % 5
    groups = packed[:usable].reshape(-1, 5)
    # Bit offsets of pixels 0..3 inside the shared low-bits byte.
    lane_shifts = np.arange(0, 8, 2, dtype=np.uint8)
    low_bits = (groups[:, 4:5] >> lane_shifts) & 0x03
    pixels = (groups[:, :4].astype(np.uint16) << 2) | low_bits
    # Row-major flattening keeps pixel order; '<u2' pins little-endian bytes.
    return pixels.astype("<u2").tobytes()
def readRAW(path):
    """Load a 96x240x256 int16 histogram cube from a .raw/.bin file.

    The on-disk layout is chosen purely by file size:

    * exactly 7,372,800 bytes: MIPI RAW10 packed data, first unpacked with
      ``mipi_raw10_to_raw8_scaled`` into little-endian 16-bit words;
    * anything else: treated as little-endian 16-bit words as-is.

    Two fix-ups are then applied: adjacent values along the last axis are
    swapped pairwise (0<->1, 2<->3, ...), and within every group of 8 rows the
    rows are reordered as [0, 4, 1, 5, 2, 6, 3, 7].

    Args:
        path: Filesystem path of the file to read.

    Returns:
        A new, writable ``(96, 240, 256)`` int16 array.

    Raises:
        ValueError: if the (unpacked) data is not exactly 96*240*256 16-bit
            words.
    """
    shape = (96, 240, 256)
    mipi_raw10_size = 7372800
    row_order = [0, 4, 1, 5, 2, 6, 3, 7]

    with open(path, "rb") as f:
        raw_data = f.read()
    # Decide on the bytes actually read instead of a separate stat() call.
    if len(raw_data) == mipi_raw10_size:
        raw_data = mipi_raw10_to_raw8_scaled(raw_data)

    expected = int(np.prod(shape)) * 2
    if len(raw_data) != expected:
        raise ValueError(
            f"{path}: expected {expected} bytes of int16 data "
            f"(or {mipi_raw10_size} bytes of MIPI RAW10), got {len(raw_data)}"
        )

    # '<i2': the MIPI unpacker emits little-endian words, so read them as
    # little-endian regardless of host byte order.
    arr = np.frombuffer(raw_data, dtype="<i2").reshape(shape)
    # Swap each adjacent pair of values along the last axis:
    # [.., 256] -> [.., 128, 2] -> reverse pairs -> [.., 256].
    swapped = arr.reshape(*shape[:-1], -1, 2)[..., ::-1].reshape(shape)
    # Reorder rows inside each group of 8; fancy indexing returns a fresh copy.
    grouped = swapped.reshape(-1, len(row_order), *shape[1:])
    return grouped[:, row_order].reshape(shape)
def load_bin(file, threshold=3):
    """Load an uploaded histogram file and build every view the UI shows.

    Args:
        file: Value of the ``gr.File`` input: an object whose ``.name`` is the
            file path, a plain path string, or ``None`` when nothing is
            uploaded.
        threshold: Forwarded unchanged to ``plot_depth``.

    Returns:
        ``(img_zoomed, raw_hist, nor_hist, depth_slider_imgs)``: a uint8
        intensity image upsampled 4x per axis, the raw cube from ``readRAW``,
        the shot-normalized cube, and the depth images from ``plot_depth``.
        All four are ``None`` when ``file`` is ``None``.
    """
    if file is None:
        # E.g. the threshold slider moved before any upload, or the file was
        # cleared: clear the outputs instead of raising AttributeError.
        return None, None, None, None
    raw_hist = readRAW(getattr(file, "name", file))

    # Bins 254/255 encode the shot count as high * 1024 + low. Widen to float
    # before multiplying: in int16, high * 1024 overflows once high >= 32.
    multishot = raw_hist[..., 254].astype(np.float32) * 1024 + raw_hist[..., 255]
    # Scale every pixel to 12000 shots. Pixels reporting no shots get factor 0
    # instead of inf/NaN, which would otherwise poison img.min()/img.max().
    normalize_data = np.divide(
        np.float32(12000),
        multishot,
        out=np.zeros_like(multishot),
        where=multishot > 0,
    )
    nor_hist = raw_hist * normalize_data[..., np.newaxis]

    # Intensity image: sum over all bins except the last two (shot counter),
    # min-max scaled to 0..255.
    img = np.sum(nor_hist[:, :, :-2], axis=2)
    lo = img.min()
    span = img.max() - lo
    norm_img = (img - lo) / span if span > 0 else np.zeros_like(img)
    img_uint8 = (norm_img * 255).astype(np.uint8)
    # Upsample 4x per axis so single pixels are easy to click.
    img_zoomed = np.repeat(np.repeat(img_uint8, 4, axis=0), 4, axis=1)

    depth_slider_imgs = plot_depth(nor_hist, threshold)
    return img_zoomed, raw_hist, nor_hist, depth_slider_imgs
def plot_pixel_histogram(evt: gr.SelectData, raw_hist, nor_hist):
    """Plot the raw histogram of the pixel the user clicked on.

    Args:
        evt: Gradio selection event; ``evt.index`` is the clicked (x, y)
            position in the displayed image, which is 4x upsampled per axis,
            hence the ``// 4``.
        raw_hist: Raw histogram cube indexed ``[row, col, bin]``, or ``None``
            if no file has been loaded yet.
        nor_hist: Normalized histogram cube; part of the event wiring but not
            plotted.

    Returns:
        A plotly figure of the raw per-bin values for that pixel, or ``None``
        when nothing has been loaded.
    """
    if raw_hist is None:
        # Image clicked before any file was loaded: nothing to plot.
        return None
    x, y = evt.index
    x, y = x // 4, y // 4
    raw_values = raw_hist[y, x, :]

    fig = go.Figure()
    fig.add_trace(go.Scatter(y=raw_values, mode="lines+markers"))
    fig.update_layout(
        title=f"Pixel ({x}, {y}) 在所有 {raw_values.shape[0]} 帧的强度变化",
        xaxis_title="帧索引 (T)",
        yaxis_title="强度值",
    )
    return fig
def to_uint8_image(arr):
    """Scale an array to uint8 so that its maximum maps to 255.

    Values are divided by the array maximum, clipped to [0, 1] and truncated
    to 0..255. An empty array, or one whose maximum is not positive, yields
    all zeros of the same shape.
    """
    values = np.asarray(arr, dtype=np.float64)
    peak = values.max() if values.size else 0.0
    if not peak > 0:
        return np.zeros(values.shape, dtype=np.uint8)
    # Divide by the exact peak (no epsilon) so the maximum lands on 255.0
    # rather than 254.999..., which the uint8 cast would truncate to 254;
    # clipping keeps negative inputs from wrapping around in that cast.
    return (np.clip(values / peak, 0.0, 1.0) * 255).astype(np.uint8)
def _scatter_floor(denoised, pool_step=10):
    """Estimate one scatter threshold per time bin, shared by all pixels.

    Each bin's slice is subsampled (every ``pool_step``-th row and column),
    rescaled so its maximum is 255, and thresholded with OpenCV's triangle
    method; the threshold is mapped back to the slice's original units.
    """
    pooled = denoised[::pool_step, ::pool_step, :]
    floor = np.zeros(denoised.shape[-1], dtype=np.float64)
    for idx in range(floor.size):
        bin_map = pooled[..., idx]
        peak = np.max(bin_map)
        if not peak > 0:
            # Nothing left above the noise floor in this bin (also avoids 1/0).
            continue
        ratio = 255.0 / peak
        # Threshold the 0-255 rescaled slice: casting the unscaled slice to
        # uint8 wraps values above 255, and dividing that threshold by
        # `ratio` would mix units.
        thresh, _ = cv2.threshold(
            (bin_map * ratio).astype(np.uint8),
            0,
            255,
            cv2.THRESH_BINARY + cv2.THRESH_TRIANGLE,
        )
        floor[idx] = thresh / ratio
    return floor


def _refine_tof(norm_hist, de_scatter, coarse_tof, r=4, neighbor_score=16, snr_th=0.1):
    """Snap each coarse ToF bin to the nearby raw peak and validate it.

    For every pixel the bin is moved to the ``norm_hist`` maximum within
    +/- ``r`` bins. It is kept only if more than ``neighbor_score`` voxels of
    the surrounding 3x3x3 (edge-clipped) neighbourhood have
    ``de_scatter / sqrt(norm_hist) > snr_th``; otherwise the pixel becomes 0.
    """
    rows, cols = coarse_tof.shape
    last = norm_hist.shape[2] - 1  # slices stop before the final bin
    refined = np.zeros_like(coarse_tof)
    for y in range(rows):
        y0, y1 = max(0, y - 1), min(rows, y + 2)
        for x in range(cols):
            x0, x1 = max(0, x - 1), min(cols, x + 2)
            t = int(coarse_tof[y, x])
            lo = max(0, t - r)
            # Offset by the real window start `lo` (not by `r`), so the
            # snapped bin is correct even when t < r.
            t = lo + int(np.argmax(norm_hist[y, x, lo:min(last, t + r + 1)]))
            t0, t1 = max(0, t - 1), min(last, t + 2)
            signal = de_scatter[y0:y1, x0:x1, t0:t1]
            total = norm_hist[y0:y1, x0:x1, t0:t1]
            snr = signal / np.sqrt(total + 1e-6)
            if np.count_nonzero(snr > snr_th) > neighbor_score:
                refined[y, x] = t
    return refined


def plot_depth(norm_hist, threshold):
    """Build a naive and a filtered depth (time-of-flight bin) image.

    Args:
        norm_hist: Shot-normalized histogram cube indexed [row, col, bin].
            It is presumably 256 bins with the shot counter in the last two;
            TODO confirm for other inputs.
        threshold: Currently unused; kept so the UI slider wiring still works.

    Returns:
        ``[colored_tof, colored_tof_filter]``: RGB VIRIDIS renderings of
        (1) the per-pixel argmax bin and (2) the noise/scatter-filtered,
        neighbour-validated bin (0 where rejected).
    """
    # Naive ToF: brightest bin per pixel, ignoring the last 5 bins.
    img_tof = to_uint8_image(np.argmax(norm_hist[..., :-5], axis=2))

    # Per-pixel noise floor from the first 8 bins: median + 4*sqrt(median),
    # i.e. roughly 4 sigma if the counts are Poisson.
    noise = np.median(norm_hist[..., :8], axis=2)
    noise_th = noise + 4 * np.sqrt(noise)
    denoised = np.maximum(norm_hist - noise_th[..., np.newaxis], 0)

    # Subtract a per-bin scatter floor shared by all pixels.
    de_scatter = np.maximum(denoised - _scatter_floor(denoised), 0)

    coarse_tof = np.argmax(de_scatter[..., :-5], axis=2)
    img_filter = to_uint8_image(_refine_tof(norm_hist, de_scatter, coarse_tof))

    # applyColorMap returns BGR; reverse channels for RGB display.
    colored_tof = cv2.applyColorMap(img_tof, cv2.COLORMAP_VIRIDIS)[:, :, ::-1]
    colored_tof_filter = cv2.applyColorMap(img_filter, cv2.COLORMAP_VIRIDIS)[:, :, ::-1]
    return [colored_tof, colored_tof_filter]
with gr.Blocks() as demo:
    gr.Markdown("## 上传 96×240×256 int16 `.bin/.raw` 文件,点击图像像素查看该像素的 256 帧直方图")
    file_input = gr.File(label="上传 .raw/.bin 文件", file_types=[".raw", ".bin"])
    image_display = gr.Image(interactive=True, label="点击像素显示强度曲线")
    histogram = gr.Plot(label="像素强度曲线")
    depth_image_slider = ImageSlider(label="Filter Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    threshold_slider = gr.Slider(1, 30, value=3, step=1, label="Mask 阈值设定 (ref > x)")
    # Per-session caches of the loaded cubes, shared between callbacks.
    raw_hist = gr.State()
    nor_hist = gr.State()
    # Uploading a file and moving the threshold slider both rebuild every view
    # through the same single entry point, load_bin.
    file_input.change(
        load_bin,
        inputs=[file_input, threshold_slider],
        outputs=[image_display, raw_hist, nor_hist, depth_image_slider],
    )
    # Clicking a pixel of the intensity image plots that pixel's histogram.
    image_display.select(
        plot_pixel_histogram,
        inputs=[raw_hist, nor_hist],
        outputs=histogram,
    )
    threshold_slider.change(
        load_bin,
        inputs=[file_input, threshold_slider],
        outputs=[image_display, raw_hist, nor_hist, depth_image_slider],
    )

# Launch only when run as a script, so importing this module does not start a
# server as a side effect.
if __name__ == "__main__":
    demo.launch()