# Anna / app_nn.py
# JohnChiu's picture
# add elegant algorithm
# 6fa13e0
import gradio as gr
import numpy as np
import plotly.graph_objs as go
from scipy.ndimage import convolve
from gradio_imageslider import ImageSlider
import cv2
import os
import ToF_utils
import onnxruntime as ort
from scipy.signal import find_peaks
import matplotlib
from PIL import Image
# def mipi_raw10_to_raw8_scaled(raw10_data):
# raw10_data = np.frombuffer(raw10_data, dtype=np.uint8)
# n_blocks = len(raw10_data) // 5
# raw10_data = raw10_data[:n_blocks * 5].reshape(-1, 5)
# p0 = (raw10_data[:, 0].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 0) & 0x03)
# p1 = (raw10_data[:, 1].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 2) & 0x03)
# p2 = (raw10_data[:, 2].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 4) & 0x03)
# p3 = (raw10_data[:, 3].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 6) & 0x03)
# raw8_data = np.empty((n_blocks * 4 * 2,), dtype=np.uint8)
# raw8_data[0::8] = p0 & 0xFF
# raw8_data[1::8] = p0 >> 8
# raw8_data[2::8] = p1 & 0xFF
# raw8_data[3::8] = p1 >> 8
# raw8_data[4::8] = p2 & 0xFF
# raw8_data[5::8] = p2 >> 8
# raw8_data[6::8] = p3 & 0xFF
# raw8_data[7::8] = p3 >> 8
# return raw8_data.tobytes()
# def readRAW(path):
# filesize = os.path.getsize(path)
# with open(path, "rb") as f:
# raw_data = f.read()
# # Case 1: 如果是 MIPI RAW10 格式,大小为 7,372,800 字节
# if filesize == 7372800:
# raw_data = ToF_utils.mipi_raw10_to_raw8_scaled(raw_data)
# # 转换为 int16 并 reshape
# arr = np.frombuffer(raw_data, dtype=np.int16).reshape(96, 240, 256)
# # Byte Swap: [x,y,256] → [x,y,128,2] → swap last dim → [x,y,256]
# reshaped = arr.reshape(*arr.shape[:-1], -1, 2)
# swapped = reshaped[..., ::-1]
# histogram_data = swapped.reshape(arr.shape)
# # Line remapping (每组8行:0,4,1,5,...)
# mapping = [0, 4, 1, 5, 2, 6, 3, 7]
# group_size = 8
# num_groups = 12 # 96 / 8
# output = np.empty_like(histogram_data)
# for g in range(num_groups):
# start = g * group_size
# end = start + group_size
# output[start:end, :, :] = histogram_data[start:end, :, :][mapping, :, :]
# return output
# session = ort.InferenceSession(r"D:\GitHub\dtof_depth_estimation_npu\unet1d_rnn.onnx", providers=['CPUExecutionProvider'])
# session = ort.InferenceSession(r"D:\GitHub\dtof_depth_estimation_npu\unet1d_tf.onnx", providers=['CPUExecutionProvider'])
# Reversed "rainbow" colormap object from matplotlib.  In the visible part of
# this file it is referenced only by commented-out code.
cmap = matplotlib.colormaps.get_cmap('rainbow_r')
# def load_bin(file,threshold_slider0,threshold_slider1):
# Integer upscaling factor of the preview image: load_bin repeats every pixel
# this many times along both axes, and plot_pixel_histogram floor-divides the
# click coordinates by it to recover the histogram indices.
zoom_scale = 6
def load_bin(file):
    """Load a raw histogram file and build the clickable preview image.

    The file is read with ``ToF_utils.readRAW``, binned with
    ``ToF_utils.binning_2x2_stride2`` and passed through
    ``ToF_utils.ma_vectorized``.  The preview is, per pixel, the sum over the
    first 180 bins of the histogram divided by ``multishot_data`` and weighted
    by the square of a linear ramp over the bin index.

    Args:
        file: Value delivered by the ``gr.File`` component: a path string, an
            object whose ``name`` attribute is the path, or ``None`` when the
            upload has been cleared.

    Returns:
        A tuple ``(img_zoomed, bin_hist, histograms_ma)``: the uint8 preview
        upscaled by ``zoom_scale``, the binned histogram and the filtered
        histogram.  All three are ``None`` when *file* is ``None``.
    """
    # The change event also fires when the file is removed; there is nothing
    # to load in that case.
    if file is None:
        return None, None, None

    path = file if isinstance(file, str) else file.name
    raw_hist = ToF_utils.readRAW(path)

    bin_hist = ToF_utils.binning_2x2_stride2(raw_hist) * 0.25 / 1023
    print('bin hist shape', bin_hist.shape, ' ', bin_hist.shape[:2])

    # Bins 254 and 255 are combined into one per-pixel number; presumably the
    # high and low word of a shot counter — TODO confirm against ToF_utils.
    multishot_data = bin_hist[..., 254] * 1024 + bin_hist[..., 255]

    # NOTE: bin_hist is edited in place, so the array handed back to the
    # caller has negatives clamped to 0 and its last two bins zeroed.
    bin_hist[bin_hist < 0] = 0
    bin_hist[..., -2:] = 0

    histograms_ma = ToF_utils.ma_vectorized(bin_hist, kernel=[1])
    # Overwrite the last five bins with the value of the fifth-from-last bin.
    histograms_ma[..., -5:] = histograms_ma[..., -5, np.newaxis]
    histograms_ma[histograms_ma < 0] = 0

    tc_range = 180
    bin_m = np.arange(0, tc_range, 1) / tc_range / tc_range

    # Divide by the per-pixel counter, leaving 0 where the counter is 0 so a
    # single such pixel cannot turn the whole preview into NaN.
    shots = multishot_data[..., np.newaxis]
    head = histograms_ma[..., :tc_range]
    per_shot = np.divide(
        head,
        shots,
        out=np.zeros(head.shape, dtype=np.float64),
        where=shots != 0,
    )
    histograms_ma_norm = per_shot * 20e3 * bin_m * bin_m
    tc = np.sum(histograms_ma_norm, axis=2)

    # Min-max normalisation: the denominator must be the value range,
    # otherwise the image never uses the full 0..255 scale when min > 0.
    tc_min = tc.min()
    norm_img = (tc - tc_min) / (tc.max() - tc_min + 1e-8)
    img_uint8 = (norm_img * 255).astype(np.uint8)
    img_zoomed = np.repeat(np.repeat(img_uint8, zoom_scale, axis=0), zoom_scale, axis=1)
    return img_zoomed, bin_hist, histograms_ma
# input_data = histograms_ma.reshape(48*120,1,256).astype(np.float32)
# outputs = session.run(None, {"input": input_data})
# result = outputs[0].squeeze(1)
# print(result.shape)
# result[result<0.1] = 0
# echo_num = 2
# first_two_peaks = np.full((5760, echo_num), -1, dtype=int)
# for i in range(5760):
# peaks, _ = find_peaks(result[i]) # local maxima indices
# # Take first two from left
# first_two_peaks[i, :len(peaks[:echo_num])] = peaks[:echo_num]
# ref_set = np.zeros((48,120,echo_num))
# value_set = np.zeros((48,120,echo_num))
# first_two_peaks = first_two_peaks.reshape(48,120,echo_num)
# print(first_two_peaks.shape) # (5760, 2)
# first_two_peaks[first_two_peaks>180] = 0
# for i in range(first_two_peaks.shape[2]):
# rows = np.arange(48)[:, None] # shape (H,1)
# cols = np.arange(120)[None, :] # shape (1,W)
# values = input_data.reshape(48,120,256)[rows, cols, first_two_peaks[...,i]] # shape (H, W)
# values = values * np.power(2,13)/multishot_data*256/48000
# value_set[...,i] = values
# tof = first_two_peaks[...,i] - 3
# tof[tof<0] = 1
# ref = tof * tof * values /1200 * 6
# ref_set[...,i] = ref
# frame1 = ref_set[...,0]
# frame2 = ref_set[...,1] * 1
# # 两个 mask 初始化为 0
# mask1 = np.zeros_like(frame1, dtype=np.uint8)
# mask2 = np.zeros_like(frame2, dtype=np.uint8)
# frame1[frame1<threshold_slider0] = 0
# frame2[frame2<threshold_slider1] = 0
# mask1[frame1 > 0] = 1
# mask2[(mask1 == 0) & (frame2 > 0)] = 1
# mask = np.stack([mask1,mask2], axis=-1) # 形状 (48, 120, 2)
# tof = first_two_peaks[...,0] * mask1 + first_two_peaks[...,1] * mask2
# peak = value_set[...,0] * mask1 + value_set[...,1] * mask2
# # # 计算类别索引:两通道为0→0;第一个通道1→1;第二个通道1→2
# class_mask = mask.argmax(axis=2) # 直接取最大通道
# # 但是需要保证 (0,0) 的情况变回 0
# class_mask[(mask.sum(axis=-1) == 0)] = -1 # 处理无效像素
# class_mask = class_mask + 1
# totalcount = np.sum(histograms_ma,axis=2)/multishot_data
# totalcount = totalcount / np.max(totalcount)
# tof0 = to_uint8_image( first_two_peaks[...,0] * mask1)
# tof1 = to_uint8_image( first_two_peaks[...,1] * mask2)
# tof0_zoomed = np.repeat(np.repeat(tof0, 4, axis=0), 4, axis=1)
# tof1_zoomed = np.repeat(np.repeat(tof1, 4, axis=0), 4, axis=1)
# tof0_colored_depth = (cmap(tof0_zoomed)[:, :, :3] * 255).astype(np.uint8)
# tof1_colored_depth = (cmap(tof1_zoomed)[:, :, :3] * 255).astype(np.uint8)
# return img_zoomed, raw_hist, nor_hist, tof0_colored_depth, tof1_colored_depth
_NN_MODEL_PATH = r"D:\GitHub\dtof_depth_estimation_npu\unet1d_tf.onnx"
_nn_session = None


def _get_nn_session():
    """Return the shared ONNX Runtime session, creating it on first use."""
    global _nn_session
    if _nn_session is None:
        _nn_session = ort.InferenceSession(_NN_MODEL_PATH, providers=['CPUExecutionProvider'])
    return _nn_session


def plot_pixel_histogram(evt: gr.SelectData, raw_hist, histograms_ma):
    """Plot the filtered histogram and the network output for a clicked pixel.

    Args:
        evt: Gradio select event; ``evt.index`` is the clicked ``(x, y)``
            position in the zoomed preview image.
        raw_hist: Array indexed as ``[row, col, bin]``; only the length of
            its last axis is used (in the figure title).
        histograms_ma: Array indexed as ``[row, col, bin]`` that is fed to
            the network and plotted for the selected pixel.

    Returns:
        A Plotly figure with the pixel's histogram (scaled by 1023), the
        network output for that pixel scaled to the histogram's maximum, and
        a dashed horizontal threshold line.  ``None`` when no histogram data
        is available.
    """
    # The preview can be replaced by the image editor without any file being
    # loaded, in which case both states are still empty.
    if raw_hist is None or histograms_ma is None:
        return None

    col, row = evt.index  # position in the zoomed image
    x = col // zoom_scale
    y = row // zoom_scale
    raw_values = raw_hist[y, x, :]

    # One network input row per pixel, in row-major order.
    height, width, n_bins = histograms_ma.shape
    input_data = histograms_ma.reshape(height * width, 1, n_bins).astype(np.float32)
    # Building an InferenceSession is expensive, so it is created once and
    # reused for every click.
    outputs = _get_nn_session().run(None, {"input": input_data})
    result = outputs[0].squeeze(1)
    print(result.shape)

    ma_values = histograms_ma[y, x, :] * 1023
    nn_prob = result[y * width + x, :] * np.max(ma_values)

    # Threshold: mean level of the first 10 and of bins 240..253, plus
    # alpha times its square root.
    alpha = 1
    lam = (np.mean(ma_values[240:254]) + np.mean(ma_values[0:10])) / 2
    max_theory = lam + alpha * np.sqrt(lam)

    fig = go.Figure()
    fig.add_trace(go.Scatter(y=ma_values, mode="lines+markers", name="Ma Histogram (rm noise)"))
    fig.add_trace(go.Scatter(y=nn_prob, mode="lines+markers", name="NN prob"))
    # Horizontal line marking the threshold level.
    fig.add_hline(y=max_theory, line_dash="dash", line_color="red", annotation_text="Threshold", annotation_position="top left")
    fig.update_layout(
        title=f"Pixel ({x}, {y}) 在所有 {raw_values.shape[0]} 帧的强度变化",
        xaxis_title="帧索引 (T)",
        yaxis_title="强度值",
    )
    return fig
def to_uint8_image(arr):
    """Scale an array by its maximum and convert it to a uint8 image.

    Args:
        arr: Numeric array (any shape) to convert.

    Returns:
        A uint8 array of the same shape holding ``arr / (max(arr) + 1e-8)``
        multiplied by 255, clipped to 0..255 and truncated to integers.
    """
    # The small epsilon keeps the division defined when the maximum is 0.
    scaled = arr / (np.max(arr) + 1e-8) * 255
    # Clip before casting: a negative float cast to uint8 wraps around and
    # would show up as a bright pixel.
    return np.clip(scaled, 0, 255).astype(np.uint8)
def plot_depth(nor_hist, threshold):
    """Render a raw and a filtered time-of-flight map as colour images.

    Each of the first 254 bins is summed over a zero-padded 3x3 spatial
    window and the maximum found in the first 10 bins is subtracted.  The raw
    map is the per-pixel argmax over the first 180 bins.  The filtered map is
    the argmax after weighting the bins with a quadratic ramp and attenuating
    a 12-bin window that starts 4 bins before the bin with the largest total
    count.

    Args:
        nor_hist: Array indexed as ``[row, col, bin]`` with at least 254 bins.
        threshold: Pixels whose weighted peak value is not above this are set
            to 0 in the filtered map before colouring.

    Returns:
        ``[colored_tof, colored_tof_filter]``: two VIRIDIS colour-mapped
        images with the channel order of ``cv2.applyColorMap`` reversed.
    """
    n_bins = 180
    kernel = np.ones((3, 3))

    # Size the buffer from the input instead of assuming a 96x240 frame.
    rows, cols = nor_hist.shape[:2]
    output = np.zeros((rows, cols, 254))
    for i in range(254):
        output[:, :, i] = convolve(nor_hist[:, :, i], kernel, mode='constant', cval=0)

    bin_index = np.arange(1, n_bins + 1)
    modulate = (bin_index * bin_index) / (n_bins * n_bins)
    arr = output[:, :, :n_bins] - np.max(output[:, :, :10])

    tc_bin = np.sum(arr, axis=(0, 1))
    max_id = int(np.argmax(tc_bin[:-2]))

    # Weight 1 everywhere except a 12-bin window starting at max_id - 4,
    # which gets the ramp 0.01..0.12.  The window is cut off at the array
    # edges so a peak near bin 0 or bin 179 no longer yields a negative
    # padding length.
    expand_kernel = np.arange(1, 13) * 0.01
    expand_filter = np.ones(n_bins)
    start = max_id - 4
    lo = max(start, 0)
    hi = min(start + len(expand_kernel), n_bins)
    expand_filter[lo:hi] = expand_kernel[lo - start:hi - start]

    print('np.max(arr,axis =(0,1,2))', np.max(arr, axis=(0, 1, 2)))
    arr_expandfilter = arr * modulate[np.newaxis, np.newaxis, :] * expand_filter[np.newaxis, np.newaxis, :]
    print('np.max(arr_expandfilter,axis =(0,1,2))', np.max(arr_expandfilter, axis=(0, 1, 2)))

    tof = np.argmax(arr, axis=2)
    tof_filter = np.argmax(arr_expandfilter, axis=2)
    ref_filter = np.max(arr_expandfilter, axis=2)
    mask = ref_filter > threshold

    # Convert to uint8 images for display.
    img_tof = to_uint8_image(tof)
    img_filter = to_uint8_image(tof_filter)
    img_filter *= mask

    colored_tof = cv2.applyColorMap(img_tof, cv2.COLORMAP_VIRIDIS)[:, :, ::-1]
    colored_tof_filter = cv2.applyColorMap(img_filter, cv2.COLORMAP_VIRIDIS)[:, :, ::-1]
    return [colored_tof, colored_tof_filter]
def preview_editor(im):
    """Return the composite image stored in an image-editor value.

    Args:
        im: Mapping with a ``"composite"`` entry, or ``None`` when the editor
            holds no value.

    Returns:
        ``im["composite"]``, or ``None`` when *im* is ``None``.
    """
    # Subscripting None would raise TypeError inside the event handler.
    if im is None:
        return None
    return im["composite"]
def on_button_click():
    """Return the fixed acknowledgement text used by the demo button."""
    message = "Button clicked!"
    return message
# Gradio UI: a file upload, a clickable preview image next to a histogram
# plot, a demo button, two image editors and an image slider, followed by the
# event wiring.
# NOTE(review): the nesting of the layout containers below was inferred from
# the order of the statements — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## 上传 96×240×256 int16 `.bin/.raw` 文件,点击图像像素查看该像素的 256 帧直方图")
    with gr.Row():
        file_input = gr.File(label="上传 .raw/.bin 文件", file_types=[".raw", ".bin"])
    with gr.Row():
        with gr.Column():
            image_display = gr.Image(interactive=False, label="点击像素显示强度曲线")
        with gr.Column():
            histogram = gr.Plot(label="像素强度曲线")
    # Create a sample image
    # sample_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
    button = gr.Button("Click me")
    with gr.Row():
        with gr.Column():
            # im_editor_1 = gr.ImageEditor( value={
            # "background": sample_image,
            # "layers": [],
            # "composite": sample_image,
            # })
            im_editor_1 = gr.ImageEditor(type="numpy", crop_size="1:1")
        with gr.Column():
            im_editor_2 = gr.ImageEditor(type="numpy", crop_size="1:1")
    depth_image_slider = ImageSlider(label="Filter Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    # Per-session storage for the two histogram arrays returned by load_bin.
    raw_hist = gr.State()
    histograms_ma = gr.State()
    # No outputs are attached, so the string returned by the handler is not
    # shown anywhere.
    button.click(on_button_click)
    # Uploading a file rebuilds the preview and refreshes both states.
    file_input.change(
        load_bin,
        inputs=[file_input],
        outputs=[image_display,raw_hist,histograms_ma]
    )
    # Clicking a preview pixel draws that pixel's histogram.
    image_display.select(
        plot_pixel_histogram,
        inputs=[raw_hist,histograms_ma],
        outputs= histogram
    )
    # Refresh the preview after the image is edited.
    im_editor_1.change(
        preview_editor,  # called after each edit
        inputs=[im_editor_1],  # input: the image editor component
        outputs=[image_display]  # output: the preview component
    )
    # threshold_slider0.change(
    # load_bin,
    # inputs=[file_input,threshold_slider0,threshold_slider1],
    # outputs=[image_display, raw_hist, nor_hist, image_display_tof0, image_display_tof1])
demo.launch()