"""Image pre-processing, Streamlit plotting and binary-mask scoring helpers."""

import io

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image

import src.constants as const

# Dictionary keys of the input sequence, grouped by the channel they belong to.
_RAIN_KEYS = ("rr_0.png", "rr_15.png", "rr_30.png", "rr_45.png", "rr_60.png")
_WIND_U_KEYS = ("wu_0.png", "wu_60.png")
_WIND_V_KEYS = ("wv_0.png", "wv_60.png")
_SAT_KEYS = ("sat_0.png", "sat_60.png")

# (dictionary key, caption) of every panel, in the order it is displayed.
_SEQ_PANELS = (
    ("rr_0.png", "Rain Radar at t = 0"),
    ("rr_15.png", "Rain Radar at t = 15"),
    ("rr_30.png", "Rain Radar at t = 30"),
    ("rr_45.png", "Rain Radar at t = 45"),
    ("rr_60.png", "Rain Radar at t = 60"),
    ("wu_0.png", "Wind U Component at t = 0"),
    ("wu_60.png", "Wind U Component at t = 60"),
    ("wv_0.png", "Wind V Component at t = 0"),
    ("wv_60.png", "Wind V Component at t = 60"),
    ("sat_0.png", "Satellite at t = 0"),
    ("sat_60.png", "Satellite at t = 60"),
)

# Number of Streamlit columns the panels are laid out in (they wrap around).
_PANEL_COLUMNS = 6


def __pre_process_rain_radar_image(image):
    """Clip a rain-radar image to the configured cut-offs and log-scale it.

    The image is clipped to ``[RAIN_RADAR_LEFT_CUTOFF, RAIN_RADAR_RIGHT_CUTOFF]``
    and mapped through ``log(x + 1) / log(RAIN_RADAR_RIGHT_CUTOFF + 1)``, so the
    right cut-off maps to 1.
    """
    image = np.clip(
        image, const.RAIN_RADAR_LEFT_CUTOFF, const.RAIN_RADAR_RIGHT_CUTOFF
    )
    return np.log(np.add(image, 1)) / np.log(const.RAIN_RADAR_RIGHT_CUTOFF + 1)


def __pre_process_sat_image(image):
    """Clip, standardise (with the sign flipped) and rescale a satellite image.

    Values are clipped to ``[-69, 20]``, standardised with ``SAT_MEAN`` /
    ``SAT_STD`` and negated, then shifted by ``|SAT_MIN|`` and divided by
    ``|SAT_MIN| + SAT_MAX``.
    """
    image = np.clip(image, -69, 20)
    # Unlike the wind channels, the standardised value is negated here.
    image = -((image - const.SAT_MEAN) / const.SAT_STD)
    return (image + abs(const.SAT_MIN)) / (abs(const.SAT_MIN) + const.SAT_MAX)


def __pre_process_wind_u(image):
    """Clip, standardise and rescale a wind U-component image.

    Values are clipped to ``[-16, 32]``, standardised with ``WIND_U_MEAN`` /
    ``WIND_U_STD``, then shifted by ``|WIND_U_MIN|`` and divided by
    ``|WIND_U_MIN| + WIND_U_MAX``.
    """
    image = np.clip(image, -16, 32)
    image = (image - const.WIND_U_MEAN) / const.WIND_U_STD
    return (image + abs(const.WIND_U_MIN)) / (
        abs(const.WIND_U_MIN) + const.WIND_U_MAX
    )


def __pre_process_wind_v(image):
    """Clip, standardise and rescale a wind V-component image.

    Values are clipped to ``[-22, 23]``, standardised with ``WIND_V_MEAN`` /
    ``WIND_V_STD``, then shifted by ``|WIND_V_MIN|`` and divided by
    ``|WIND_V_MIN| + WIND_V_MAX``.
    """
    image = np.clip(image, -22, 23)
    image = (image - const.WIND_V_MEAN) / const.WIND_V_STD
    return (image + abs(const.WIND_V_MIN)) / (
        abs(const.WIND_V_MIN) + const.WIND_V_MAX
    )


def __process_pixel_correction(images_dic):
    """Convert every input image from 0-255 pixel values back to data values.

    Each entry is turned into a NumPy array and denormalised with the
    constants of its channel (``RAIN_PIXEL_CORR``, ``WU_PIXEL_CORR``,
    ``WV_PIXEL_CORR``, ``SAT_PIXEL_CORR``).

    Args:
        images_dic: Mapping from file name to image. It is updated in place.

    Returns:
        The same ``images_dic`` object, with the eleven input entries replaced.

    Raises:
        KeyError: If one of the expected file names is missing.
    """

    def denormalize(data, constants):
        # Undo the 0-255 scaling of the log-transformed value, then undo the
        # log transform itself using the saved constants.
        data = data.astype(float)
        log_span = constants['max_log_val'] - constants['min_log_val']
        log_val = data / 255 * log_span + constants['min_log_val']
        return np.exp(log_val) + constants['offset'] - 1

    corrections = (
        (_RAIN_KEYS, const.RAIN_PIXEL_CORR),
        (_WIND_U_KEYS, const.WU_PIXEL_CORR),
        (_WIND_V_KEYS, const.WV_PIXEL_CORR),
        (_SAT_KEYS, const.SAT_PIXEL_CORR),
    )
    for keys, constants in corrections:
        for key in keys:
            images_dic[key] = denormalize(np.array(images_dic[key]), constants)
    return images_dic


def process_input_seq(images_dic):
    """Apply pixel correction and per-channel pre-processing to a sequence.

    Args:
        images_dic: Mapping from file name to image. It is updated in place.

    Returns:
        A tuple ``(images_dic, 1)``. The second element is always the
        constant ``1``.

    Raises:
        KeyError: If one of the expected file names is missing.
    """
    images_dic = __process_pixel_correction(images_dic)

    steps = (
        (_RAIN_KEYS, __pre_process_rain_radar_image),
        (_SAT_KEYS, __pre_process_sat_image),
        (_WIND_U_KEYS, __pre_process_wind_u),
        (_WIND_V_KEYS, __pre_process_wind_v),
    )
    for keys, pre_process in steps:
        for key in keys:
            images_dic[key] = pre_process(images_dic[key])
    return images_dic, 1


def remove_zero_pad(image):
    """Crop ``image`` to the bounding box of its pixels with value below 245.

    NOTE(review): despite the function name, the test is ``image < 245``, so
    values of 245 and above are what gets treated as padding -- confirm this
    matches the images callers pass in.

    Raises:
        ValueError: If no pixel is below 245 (the reductions run on an empty
            array).
    """
    content = np.argwhere(image < 245)
    rows, cols = content[:, 0], content[:, 1]
    min_y, max_y = rows.min(), rows.max()
    min_x, max_x = cols.min(), cols.max()
    # NOTE(review): the slice end is exclusive, so the last content row and
    # column (max_y / max_x) are cut off. Kept as-is because callers may rely
    # on the resulting shape; use ``max + 1`` if the full box is wanted.
    return image[min_y:max_y, min_x:max_x]


def _render_figure(draw):
    """Draw onto a fresh axis-less figure and return it as a PIL image.

    Args:
        draw: Callable receiving the matplotlib axes to draw on.
    """
    buf = io.BytesIO()
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()  # remove axis ticks and labels
        fig.tight_layout(pad=0)
        draw(ax)
        fig.savefig(buf)
    finally:
        # pyplot keeps every figure registered until it is closed; without
        # this each call leaks a figure.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def fig2img(img):
    """Render ``img`` with the viridis colormap and return it as a PIL image."""
    return _render_figure(lambda ax: ax.imshow(img, cmap='viridis'))


def _show_panels(render):
    """Show every sequence panel in a six-column Streamlit layout.

    Args:
        render: Callable mapping a dictionary key to the image to display.
    """
    columns = st.columns(_PANEL_COLUMNS)
    for index, (key, caption) in enumerate(_SEQ_PANELS):
        # Panels beyond the sixth wrap around to the first column again.
        columns[index % _PANEL_COLUMNS].image(
            render(key), use_column_width=True, caption=caption
        )


def plot_seq(seq):
    """Display all images of ``seq`` in Streamlit, one captioned panel each."""
    _show_panels(lambda key: fig2img(seq.get(key)))


def fig2img_overlap(img, overlap):
    """Render ``img`` in grey with ``overlap`` on top and return a PIL image.

    The overlay uses the ``Reds`` colormap with its lowest eighth of entries
    made fully transparent, drawn at 70 % opacity.
    """
    base_cmap = plt.cm.Reds
    # Copy the base colormap's RGBA table and zero the alpha of its low end.
    colors = base_cmap(np.arange(base_cmap.N))
    transparent_count = base_cmap.N // 8
    colors[:transparent_count, -1] = 0
    hot_alpha = LinearSegmentedColormap.from_list(
        'hot_alpha', colors, base_cmap.N
    )

    def draw(ax):
        ax.imshow(img, cmap="gray")
        ax.imshow(overlap, cmap=hot_alpha, alpha=0.7)

    return _render_figure(draw)


def plot_seq_with_overlap(seq, overlap_seq):
    """Display all images of ``seq`` with the matching ``overlap_seq`` overlay."""
    _show_panels(
        lambda key: fig2img_overlap(seq.get(key), overlap_seq.get(key))
    )


def __calculate_tp_fp_fn_tn(pred, target):
    """Count true/false positives and negatives of a prediction.

    Both inputs are expected to hold only the values 0 and 1 (assumption --
    other values are not counted in any bucket).

    Returns:
        A list ``[tp, fp, fn, tn]`` of element counts.
    """
    # Work in float so the encoding below cannot wrap around for unsigned
    # inputs (0 - 1 on uint8 gives 255) and plain lists are handled too.
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    # Encodes each (pred, target) pair as one value:
    # (1, 1) -> 1, (1, 0) -> 2, (0, 1) -> -1, (0, 0) -> 0.
    diff = 2 * pred - target
    return [
        (diff == 1).sum(),
        (diff == 2).sum(),
        (diff == -1).sum(),
        (diff == 0).sum(),
    ]


def get_precision(pred, target):
    """Return ``tp / (tp + fp)``.

    The result is nan (with a NumPy RuntimeWarning) when there are no
    predicted positives.
    """
    tp, fp, _, _ = __calculate_tp_fp_fn_tn(pred, target)
    return tp / (tp + fp)


def get_recall(pred, target):
    """Return ``tp / (tp + fn)``.

    The result is nan (with a NumPy RuntimeWarning) when there are no
    actual positives.
    """
    tp, _, fn, _ = __calculate_tp_fp_fn_tn(pred, target)
    return tp / (tp + fn)


def get_f1(pred, target):
    """Return the harmonic mean of precision and recall.

    The result is nan when precision or recall is nan, or when both are zero.
    """
    # Count once instead of once per metric.
    tp, fp, fn, _ = __calculate_tp_fp_fn_tn(pred, target)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def get_csi(pred, target):
    """Return the critical success index ``tp / (tp + fp + fn)``.

    The result is nan (with a NumPy RuntimeWarning) when the denominator is
    zero.
    """
    tp, fp, fn, _ = __calculate_tp_fp_fn_tn(pred, target)
    return tp / (tp + fp + fn)