import gradio as gr import numpy as np import plotly.graph_objs as go from scipy.ndimage import convolve from gradio_imageslider import ImageSlider import cv2 import os import ToF_utils import onnxruntime as ort from scipy.signal import find_peaks import matplotlib from PIL import Image # def mipi_raw10_to_raw8_scaled(raw10_data): # raw10_data = np.frombuffer(raw10_data, dtype=np.uint8) # n_blocks = len(raw10_data) // 5 # raw10_data = raw10_data[:n_blocks * 5].reshape(-1, 5) # p0 = (raw10_data[:, 0].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 0) & 0x03) # p1 = (raw10_data[:, 1].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 2) & 0x03) # p2 = (raw10_data[:, 2].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 4) & 0x03) # p3 = (raw10_data[:, 3].astype(np.uint16) << 2) | ((raw10_data[:, 4] >> 6) & 0x03) # raw8_data = np.empty((n_blocks * 4 * 2,), dtype=np.uint8) # raw8_data[0::8] = p0 & 0xFF # raw8_data[1::8] = p0 >> 8 # raw8_data[2::8] = p1 & 0xFF # raw8_data[3::8] = p1 >> 8 # raw8_data[4::8] = p2 & 0xFF # raw8_data[5::8] = p2 >> 8 # raw8_data[6::8] = p3 & 0xFF # raw8_data[7::8] = p3 >> 8 # return raw8_data.tobytes() # def readRAW(path): # filesize = os.path.getsize(path) # with open(path, "rb") as f: # raw_data = f.read() # # Case 1: 如果是 MIPI RAW10 格式,大小为 7,372,800 字节 # if filesize == 7372800: # raw_data = ToF_utils.mipi_raw10_to_raw8_scaled(raw_data) # # 转换为 int16 并 reshape # arr = np.frombuffer(raw_data, dtype=np.int16).reshape(96, 240, 256) # # Byte Swap: [x,y,256] → [x,y,128,2] → swap last dim → [x,y,256] # reshaped = arr.reshape(*arr.shape[:-1], -1, 2) # swapped = reshaped[..., ::-1] # histogram_data = swapped.reshape(arr.shape) # # Line remapping (每组8行:0,4,1,5,...) 
#     mapping = [0, 4, 1, 5, 2, 6, 3, 7]
#     group_size = 8
#     num_groups = 12  # 96 / 8
#     output = np.empty_like(histogram_data)
#     for g in range(num_groups):
#         start = g * group_size
#         end = start + group_size
#         output[start:end, :, :] = histogram_data[start:end, :, :][mapping, :, :]
#     return output

# session = ort.InferenceSession(r"D:\GitHub\dtof_depth_estimation_npu\unet1d_rnn.onnx", providers=['CPUExecutionProvider'])
# session = ort.InferenceSession(r"D:\GitHub\dtof_depth_estimation_npu\unet1d_tf.onnx", providers=['CPUExecutionProvider'])
cmap = matplotlib.colormaps.get_cmap('rainbow_r')
# def load_bin(file,threshold_slider0,threshold_slider1):

# Integer up-scaling factor applied to the preview image on both axes; click
# coordinates on the preview are divided by the same factor.
zoom_scale = 6


def load_bin(file):
    """Load a histogram file and build the preview image and histogram arrays.

    Args:
        file: Value delivered by the upload component. Either an object with
            a ``name`` attribute holding the path, a plain path, or ``None``
            (e.g. when the upload is cleared).

    Returns:
        A 3-tuple ``(img_zoomed, bin_hist, histograms_ma)``:

        * ``img_zoomed`` -- uint8 preview image, each pixel repeated
          ``zoom_scale`` times along both axes.
        * ``bin_hist`` -- output of ``ToF_utils.binning_2x2_stride2`` scaled
          by ``0.25 / 1023``, with negative values clamped to 0 and the last
          two bins zeroed.
        * ``histograms_ma`` -- output of ``ToF_utils.ma_vectorized`` with the
          last five bins overwritten by the value of bin ``-5``.

        ``(None, None, None)`` is returned when ``file`` is ``None``.
    """
    # The change event also fires when the upload is cleared; there is
    # nothing to load in that case.
    if file is None:
        return None, None, None

    # Accept both file-like wrappers exposing ``.name`` and plain paths.
    path = getattr(file, "name", file)
    raw_hist = ToF_utils.readRAW(path)

    bin_hist = ToF_utils.binning_2x2_stride2(raw_hist) * 0.25 / 1023
    # Same output as the former np.max(bin_hist, axis=2).shape, without
    # running a full reduction just for a debug message.
    print('bin hist shape', bin_hist.shape, ' ', bin_hist.shape[:2])
    # bin_hist = bin_hist / np.max(bin_hist[...,:-2],axis=(2))[...,np.newaxis]

    # Bins 254 and 255 are combined as high * 1024 + low.
    # NOTE(review): presumably the per-pixel shot count -- confirm against
    # the sensor format. Must be read before those bins are zeroed below.
    multishot_data = bin_hist[..., 254] * 1024 + bin_hist[..., 255]

    # Noise-floor subtraction is currently disabled (alpha = 0, dcr = 0 and
    # the subtraction itself was commented out):
    #   lam = np.median(bin_hist[..., 0:253], axis=2)
    #   max_theory = lam + alpha * np.sqrt(lam) + dcr
    #   bin_hist_rm_noise = bin_hist - max_theory[..., np.newaxis]
    # The clean-up below therefore operates on ``bin_hist`` itself, in place,
    # so the returned ``bin_hist`` carries these modifications too.
    bin_hist[bin_hist < 0] = 0
    bin_hist[..., -2:] = 0

    # histograms_ma = ToF_utils.ma_vectorized(bin_hist_rm_noise,kernel=[1,1,1])
    histograms_ma = ToF_utils.ma_vectorized(bin_hist, kernel=[1])
    # Overwrite the last five bins with the value of bin -5.
    histograms_ma[..., -5:] = histograms_ma[..., -5, np.newaxis]
    histograms_ma[histograms_ma < 0] = 0

    # Weighted sum over the first ``tc_range`` bins, normalised per pixel by
    # ``multishot_data``.
    tc_range = 180
    bin_m = np.arange(0, tc_range, 1) / tc_range / tc_range
    signal = histograms_ma[..., :tc_range]
    shots = multishot_data[..., np.newaxis]
    # Pixels with a zero divisor contribute 0 instead of inf/NaN, which would
    # otherwise poison tc.min()/tc.max() and blank the whole preview.
    per_shot = np.divide(
        signal,
        shots,
        out=np.zeros(signal.shape, dtype=np.float64),
        where=shots != 0,
    )
    tc = np.sum(per_shot * 20e3 * bin_m * bin_m, axis=2)

    # Min-max normalisation: divide by the value range (not by the maximum)
    # so the preview spans the full 0..255 range.
    tc_min = tc.min()
    norm_img = (tc - tc_min) / (tc.max() - tc_min + 1e-8)
    img_uint8 = (norm_img * 255).astype(np.uint8)
    img_zoomed = np.repeat(np.repeat(img_uint8, zoom_scale, axis=0), zoom_scale, axis=1)
    return img_zoomed, bin_hist, histograms_ma

# --- Disabled code that followed the return above (kept for reference) ---
#     input_data = histograms_ma.reshape(48*120,1,256).astype(np.float32)
#     outputs = session.run(None, {"input": input_data})
#     result = outputs[0].squeeze(1)
#     print(result.shape)
#     result[result<0.1] = 0
#     echo_num = 2
#     first_two_peaks = np.full((5760, echo_num), -1, dtype=int)
#     for i in range(5760):
#         peaks, _ = find_peaks(result[i])  # local maxima indices
#         # Take first two from left
#         first_two_peaks[i, :len(peaks[:echo_num])] = peaks[:echo_num]
#     ref_set = np.zeros((48,120,echo_num))
#     value_set = np.zeros((48,120,echo_num))
#     first_two_peaks = first_two_peaks.reshape(48,120,echo_num)
#     print(first_two_peaks.shape)  # (5760, 2)
#     first_two_peaks[first_two_peaks>180] = 0
#     for i in range(first_two_peaks.shape[2]):
#         rows = np.arange(48)[:, None]  # shape (H,1)
#         cols = np.arange(120)[None, :]  # shape (1,W)
#         values = input_data.reshape(48,120,256)[rows, cols, first_two_peaks[...,i]]  # shape (H, W)
#         values = values * np.power(2,13)/multishot_data*256/48000
#         value_set[...,i] = values
#         tof = first_two_peaks[...,i] - 3
#         tof[tof<0] = 1
#         ref = tof * tof * values /1200 * 6
#         ref_set[...,i] = ref
#     frame1 = ref_set[...,0]
#     frame2 = ref_set[...,1] * 1
#     # initialise both masks to 0
#     mask1 = np.zeros_like(frame1, dtype=np.uint8)
#     mask2 = np.zeros_like(frame2, dtype=np.uint8)
#     frame1[frame1 0] = 1  # NOTE(review): this line appears truncated in the source
#     mask2[(mask1 == 0) & (frame2 > 0)] = 1
#     mask = np.stack([mask1,mask2], axis=-1)  # shape (48, 120, 2)
#     tof = first_two_peaks[...,0] * mask1 + first_two_peaks[...,1] * mask2
#     peak = value_set[...,0] * mask1 + value_set[...,1] * mask2
#     # compute the class index: both channels 0 -> 0; first channel 1 -> 1; second channel 1 -> 2
#     class_mask = mask.argmax(axis=2)  # take the max channel directly
#     # but the (0,0) case must map back to 0
#     class_mask[(mask.sum(axis=-1) == 0)] = -1  # handle invalid pixels
#     class_mask = class_mask + 1
#     totalcount = np.sum(histograms_ma,axis=2)/multishot_data
#     totalcount = totalcount / np.max(totalcount)
#     tof0 = to_uint8_image( first_two_peaks[...,0] * mask1)
#     tof1 = to_uint8_image( first_two_peaks[...,1] * mask2)
#     tof0_zoomed = np.repeat(np.repeat(tof0, 4, axis=0), 4, axis=1)
#     tof1_zoomed = np.repeat(np.repeat(tof1, 4, axis=0), 4, axis=1)
#     tof0_colored_depth = (cmap(tof0_zoomed)[:, :, :3] * 255).astype(np.uint8)
#     tof1_colored_depth = (cmap(tof1_zoomed)[:, :, :3] * 255).astype(np.uint8)
#     return img_zoomed, raw_hist, nor_hist, tof0_colored_depth, tof1_colored_depth


# ONNX model evaluated for the "NN prob" trace of the per-pixel plot.
_MODEL_PATH = r"D:\GitHub\dtof_depth_estimation_npu\unet1d_tf.onnx"
# Lazily created inference session, reused by every click handler call.
_ort_session = None


def _get_session():
    """Return the shared ONNX Runtime session, creating it on first use."""
    global _ort_session
    if _ort_session is None:
        _ort_session = ort.InferenceSession(
            _MODEL_PATH, providers=['CPUExecutionProvider']
        )
    return _ort_session


def plot_pixel_histogram(evt: gr.SelectData, raw_hist, histograms_ma):
    """Plot the histogram of the clicked pixel together with the model output.

    Args:
        evt: Gradio select event; ``evt.index`` is unpacked as ``(x, y)`` in
            preview-image coordinates.
        raw_hist: 3-D array indexed as ``[y, x, bin]``; only the length of its
            last axis is used (in the figure title).
        histograms_ma: 3-D array indexed as ``[y, x, bin]``; fed to the model
            and plotted for the selected pixel.

    Returns:
        A ``plotly.graph_objs.Figure``. An empty figure is returned when no
        histogram data is available or the click maps outside the array.
    """
    # The preview can be clicked before any data has been loaded.
    if raw_hist is None or histograms_ma is None:
        return go.Figure()

    # Map preview coordinates back to histogram pixel coordinates.
    x, y = evt.index  # Gradio SelectData object
    x = x // zoom_scale
    y = y // zoom_scale

    height, width, n_bins = histograms_ma.shape
    if not (0 <= y < height and 0 <= x < width):
        return go.Figure()

    raw_values = raw_hist[y, x, :]

    # One model sample per pixel, shaped (pixels, 1, bins).
    # NOTE(review): the whole frame is evaluated on every click although only
    # one row of the result is used; batching is kept because the model may
    # have been exported with a fixed batch size -- confirm before reducing.
    input_data = histograms_ma.reshape(height * width, 1, n_bins).astype(np.float32)
    outputs = _get_session().run(None, {"input": input_data})
    result = outputs[0].squeeze(1)
    print(result.shape)

    ma_values = histograms_ma[y, x, :] * 1023
    # Scale the model output to the peak of the plotted histogram so both
    # traces share one axis.
    nn_prob = result[y * width + x, :] * np.max(ma_values)

    # Threshold line: average of the mean of bins 240..253 and the mean of
    # bins 0..9, plus ``alpha`` times its square root.
    alpha = 1
    lam = (np.mean(ma_values[240:254]) + np.mean(ma_values[0:10])) / 2
    max_theory = lam + alpha * np.sqrt(lam)
    # ma_values[ma_values>10]=0

    fig = go.Figure()
    # fig.add_trace(go.Scatter(y=raw_values, mode="lines+markers",name="Raw Histogram"))
    fig.add_trace(go.Scatter(y=ma_values, mode="lines+markers", name="Ma Histogram (rm noise)"))
    fig.add_trace(go.Scatter(y=nn_prob, mode="lines+markers", name="NN prob"))
    # Add a horizontal line (e.g. y=0.5)
    fig.add_hline(
        y=max_theory,
        line_dash="dash",
        line_color="red",
        annotation_text="Threshold",
        annotation_position="top left",
    )
    fig.update_layout(
        title=f"Pixel ({x}, {y}) 在所有 {raw_values.shape[0]} 帧的强度变化",
        xaxis_title="帧索引 (T)",
        yaxis_title="强度值",
    )
    return fig


def to_uint8_image(arr):
    """Scale ``arr`` by its maximum and convert it to a uint8 image.

    Values are divided by ``max(arr) + 1e-8``, clipped to ``[0, 1]`` and
    multiplied by 255. Clipping keeps negative inputs from being cast to
    uint8 out of range. An all-zero image of the same shape is returned when
    the input is empty or its maximum is not positive.
    """
    values = np.asarray(arr, dtype=np.float64)
    if values.size == 0:
        return np.zeros(values.shape, dtype=np.uint8)
    peak = np.max(values)
    if peak <= 0:
        return np.zeros(values.shape, dtype=np.uint8)
    norm = np.clip(values / (peak + 1e-8), 0.0, 1.0)
    return (norm * 255).astype(np.uint8)


def plot_depth(nor_hist, threshold):
    """Build two colour-mapped peak-position images from a histogram cube.

    Args:
        nor_hist: 3-D array indexed as ``[row, col, bin]`` with at least 180
            bins along the last axis.
        threshold: Pixels whose filtered peak value is not greater than this
            are zeroed in the filtered image.

    Returns:
        ``[colored_tof, colored_tof_filter]`` -- two RGB uint8 images: the
        arg-max bin of the smoothed histogram, and the arg-max bin after the
        quadratic weighting and the attenuation window, masked by
        ``threshold``.

    Raises:
        ValueError: If ``nor_hist`` has fewer than 180 bins.
    """
    n_bins = 180
    # Attenuation ramp (0.01 .. 0.12) placed around the global peak bin.
    ramp = np.arange(1, 13) * 0.01

    hist = np.asarray(nor_hist)
    if hist.ndim != 3 or hist.shape[2] < n_bins:
        raise ValueError(f"nor_hist must be 3-D with at least {n_bins} bins")

    # 3x3 spatial box sum per bin. A (3, 3, 1) kernel does not mix bins, so
    # this matches a per-bin 2-D convolution. Only the first ``n_bins`` bins
    # are used below. Converting to float64 first avoids integer overflow in
    # the sum for integer inputs.
    kernel = np.ones((3, 3, 1))
    output = convolve(
        hist[:, :, :n_bins].astype(np.float64), kernel, mode='constant', cval=0
    )

    # Quadratic weight (bin index squared, normalised by n_bins squared).
    modulate1 = np.arange(1, n_bins + 1)
    modulate = (modulate1 * modulate1) / (n_bins * n_bins)

    # Subtract the global maximum found in the first 10 bins.
    arr = output - np.max(output[:, :, :10])

    tc_bin = np.sum(arr, axis=(0, 1))
    max_id = np.argmax(tc_bin[:-2])

    # Window of ones with the ramp starting 4 bins before the global peak.
    # The start is clamped so the ramp always fits; previously a peak in the
    # first 4 or last 7 bins requested a negative-length array and raised.
    start = int(np.clip(max_id - 4, 0, n_bins - len(ramp)))
    expand_filter = np.ones(n_bins)
    expand_filter[start:start + len(ramp)] = ramp

    print('np.max(arr,axis =(0,1,2))', np.max(arr, axis=(0, 1, 2)))
    arr_expandfilter = (
        arr
        * modulate[np.newaxis, np.newaxis, :]
        * expand_filter[np.newaxis, np.newaxis, :]
    )
    print('np.max(arr_expandfilter,axis =(0,1,2))', np.max(arr_expandfilter, axis=(0, 1, 2)))

    tof = np.argmax(arr, axis=2)
    tof_filter = np.argmax(arr_expandfilter, axis=2)
    ref_filter = np.max(arr_expandfilter, axis=2)
    mask = ref_filter > threshold

    # Convert to uint8 images for display.
    img_tof = to_uint8_image(tof)
    img_filter = to_uint8_image(tof_filter)
    img_filter *= mask

    # Reverse the channel axis of the colour-mapped result (BGR -> RGB).
    colored_tof = cv2.applyColorMap(img_tof, cv2.COLORMAP_VIRIDIS)[:, :, ::-1]
    colored_tof_filter = cv2.applyColorMap(img_filter, cv2.COLORMAP_VIRIDIS)[:, :, ::-1]
    return [colored_tof, colored_tof_filter]


def preview_editor(im):
    """Return the composite image of an image-editor value.

    Args:
        im: Mapping with a ``"composite"`` entry, or ``None`` when the editor
            holds no value.

    Returns:
        ``im["composite"]``, or ``None`` when ``im`` is ``None``.
    """
    if im is None:
        return None
    return im["composite"]


def on_button_click():
    """Return a fixed confirmation message."""
    return "Button clicked!"
# UI definition.
# NOTE(review): the nesting of rows/columns was reconstructed from a
# whitespace-collapsed source -- confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## 上传 96×240×256 int16 `.bin/.raw` 文件,点击图像像素查看该像素的 256 帧直方图")

    with gr.Row():
        file_input = gr.File(label="上传 .raw/.bin 文件", file_types=[".raw", ".bin"])

    with gr.Row():
        with gr.Column():
            image_display = gr.Image(interactive=False, label="点击像素显示强度曲线")
        with gr.Column():
            histogram = gr.Plot(label="像素强度曲线")

    # Create a sample image
    # sample_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
    button = gr.Button("Click me")

    with gr.Row():
        with gr.Column():
            # im_editor_1 = gr.ImageEditor( value={
            #     "background": sample_image,
            #     "layers": [],
            #     "composite": sample_image,
            # })
            im_editor_1 = gr.ImageEditor(type="numpy", crop_size="1:1")
        with gr.Column():
            im_editor_2 = gr.ImageEditor(type="numpy", crop_size="1:1")

    depth_image_slider = ImageSlider(
        label="Filter Depth Map with Slider View",
        elem_id='img-display-output',
        position=0.5,
    )

    # Per-session storage for the arrays returned by load_bin.
    raw_hist = gr.State()
    histograms_ma = gr.State()

    button.click(on_button_click)

    # Loading a file refreshes the preview image and both state arrays.
    file_input.change(
        load_bin,
        inputs=[file_input],
        outputs=[image_display, raw_hist, histograms_ma],
    )

    # Clicking a preview pixel plots that pixel's histogram.
    image_display.select(
        plot_pixel_histogram,
        inputs=[raw_hist, histograms_ma],
        outputs=histogram,
    )

    # Update the preview after the image has been edited.
    im_editor_1.change(
        preview_editor,            # runs after each edit
        inputs=[im_editor_1],      # input: the image editor component
        outputs=[image_display],   # output: the preview component
    )

    # threshold_slider0.change(
    #     load_bin,
    #     inputs=[file_input,threshold_slider0,threshold_slider1],
    #     outputs=[image_display, raw_hist, nor_hist, image_display_tof0, image_display_tof1])

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or a reloader) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()