| import os |
| import shutil |
| import tempfile |
| import gradio as gr |
| import plotly.graph_objects as go |
|
|
| import pandas as pd |
| from time import time |
| from utils import ( |
| create_file_structure, |
| init_info_csv, |
| add_to_info_csv, |
| ) |
|
|
| from satseg.dataset import create_datasets, create_inference_dataset |
| from satseg.model import train_model, save_model, run_inference, load_model |
| from satseg.seg_result import combine_seg_maps, get_combined_map_contours |
| from satseg.geo_tools import ( |
| shapefile_to_latlong, |
| shapefile_to_grid_indices, |
| points_to_shapefile, |
| contours_to_shapefile, |
| get_tif_n_channels, |
| ) |
|
|
# On-disk layout: every artifact the app persists lives under DATA_DIR.
DATA_DIR = "data"
MODEL_DIR = os.path.join(DATA_DIR, "models")
TIF_DIR = os.path.join(DATA_DIR, "tifs")
MASK_DIR = os.path.join(DATA_DIR, "masks")
INFO_DIR = os.path.join(DATA_DIR, "info")

# CSV registries recording saved models and uploaded dataset files.
MODEL_INFO_PATH = os.path.join(INFO_DIR, "model_data.csv")
DATASET_TIF_INFO_PATH = os.path.join(INFO_DIR, "dataset_tif_data.csv")
DATASET_MASK_INFO_PATH = os.path.join(INFO_DIR, "dataset_mask_data.csv")

# Prepare directories and registry files at import time. MODEL_DIR is listed
# because trained checkpoints are saved into it; parents are listed before
# children in case create_file_structure creates them in order (see utils).
create_file_structure(
    [DATA_DIR, MODEL_DIR, TIF_DIR, MASK_DIR, INFO_DIR],
    [MODEL_INFO_PATH, DATASET_TIF_INFO_PATH, DATASET_MASK_INFO_PATH],
)
# Column headers for each registry. Presumably init_info_csv only writes the
# header when the file is new/empty -- TODO confirm in utils.
init_info_csv(
    MODEL_INFO_PATH,
    [
        "Name",
        "Architecture",
        "# of channels",
        "Train TIF",
        "Train Mask",
        "Expression",
        "Path",
    ],
)
init_info_csv(DATASET_TIF_INFO_PATH, ["Name", "# of channels", "Path"])
init_info_csv(DATASET_MASK_INFO_PATH, ["Name", "Class", "Path"])
|
|
|
|
def gr_train_model(
    tif_names, mask_names, model_name, expression, progress=gr.Progress()
):
    """Train a UNet on the selected TIF/mask files and register the model.

    Args:
        tif_names: Names of uploaded TIFs, relative to ``TIF_DIR``.
        mask_names: Names of uploaded masks, relative to ``MASK_DIR``.
        model_name: Model name; whitespace runs become underscores and
            ``.pt`` is appended to form the saved file name.
        expression: Whitespace-separated remote-sensing index expression,
            forwarded to ``create_datasets`` as a list of tokens.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        A status string, the refreshed model registry DataFrame, and a
        dropdown update listing every registered model name.

    Raises:
        ValueError: If no TIF or mask is selected, or the model name is blank.
    """
    if not tif_names or not mask_names:
        raise ValueError("Select at least one TIF file and one mask file.")
    if not model_name or not model_name.strip():
        raise ValueError("Give the model a name.")

    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
    mask_paths = [os.path.join(MASK_DIR, name) for name in mask_names]
    # Tokenize the index expression; an empty/missing expression gives [].
    expression = (expression or "").split()

    progress(0, desc="Creating Dataset...")
    with tempfile.TemporaryDirectory() as tempdir:
        train_set, val_set = create_datasets(
            tif_paths, mask_paths, tempdir, expression=expression
        )
        progress(0.05, desc="Training Model...")
        model, _ = train_model(train_set, val_set, "unet")
        # Read while tempdir still exists, in case the dataset is file-backed.
        n_channels = val_set.n_channels

    progress(0.95, desc="Model Trained! Saving...")
    model_name = "_".join(model_name.split()) + ".pt"
    model_path = os.path.join(MODEL_DIR, model_name)
    # Ensure the target directory exists rather than relying on it having
    # been created elsewhere (or on save_model creating parents).
    os.makedirs(MODEL_DIR, exist_ok=True)
    save_model(model, model_path)
    add_to_info_csv(
        MODEL_INFO_PATH,
        [
            model_name,
            "UNet",
            n_channels,
            ";".join(tif_names),
            ";".join(mask_names),
            " ".join(expression),
            model_path,
        ],
    )
    progress(1.0, desc="Done!")
    model_df = pd.read_csv(MODEL_INFO_PATH)
    return "Done!", model_df, gr.Dropdown.update(choices=model_df["Name"].to_list())
|
|
|
|
def gr_run_inference(tif_names, model_name, progress=gr.Progress()):
    """Segment the selected TIFs with a registered model and export shapefiles.

    Args:
        tif_names: Names of uploaded TIFs, relative to ``TIF_DIR``.
        model_name: ``Name`` value of a row in the model registry CSV.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        The paths returned by ``contours_to_shapefile``, one per entry in the
        combined segmentation results.

    Raises:
        ValueError: If no TIF is selected or ``model_name`` is not registered.
    """
    start = time()
    if not tif_names:
        raise ValueError("Select at least one TIF file.")
    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]

    model_df = pd.read_csv(MODEL_INFO_PATH)
    matches = model_df[model_df["Name"] == model_name]
    if matches.empty:
        raise ValueError(f"Unknown model: {model_name!r}")
    # Retraining under an existing name appends another row, which would make
    # a plain index lookup return a Series. Use the most recent entry.
    model_row = matches.iloc[-1]
    model_path = model_row["Path"]
    # An empty expression is stored as "" and read back by pandas as NaN
    # (a float); map it to the same empty token list used during training.
    raw_expression = model_row["Expression"]
    expression = raw_expression.split() if isinstance(raw_expression, str) else []

    with tempfile.TemporaryDirectory() as tempdir:
        progress(0, desc="Creating Dataset...")
        dataset = create_inference_dataset(
            tif_paths,
            tempdir,
            256,
            expression=expression,
        )
        progress(0.1, desc="Loading Model...")
        model = load_model(model_path)

        result_dir = os.path.join(tempdir, "infer")
        comb_result_dir = os.path.join(tempdir, "comb")
        os.makedirs(result_dir)
        os.makedirs(comb_result_dir)
        progress(0.2, desc="Running Inference...")
        run_inference(dataset, model, result_dir)
        progress(0.8, desc="Preparing output...")
        combine_seg_maps(result_dir, comb_result_dir)
        results = get_combined_map_contours(comb_result_dir)

        # NOTE(review): this output directory is shared and wiped on every
        # call; if several inference jobs run concurrently they can delete
        # each other's files. Consider a per-request directory.
        out_dir = os.path.join(MASK_DIR, "output")
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)

        file_paths = []
        for tif_name, (contours, hierarchy) in results.items():
            tif_path = os.path.join(TIF_DIR, f"{tif_name}.tif")
            mask_path = os.path.join(out_dir, f"{tif_name}_mask.shp")
            zip_path = contours_to_shapefile(contours, hierarchy, tif_path, mask_path)
            file_paths.append(zip_path)

    print(time() - start, "seconds")
    return file_paths
|
|
|
|
def gr_save_mask_file(file_objs, filenames, obj_class):
    """Move uploaded mask files into MASK_DIR and register the shapefiles.

    Args:
        file_objs: Uploaded temp files from the Gradio ``File`` component, in
            the same order as the names in ``filenames``.
        filenames: ``;``-separated destination names; blank entries are
            skipped. Only the base name of each entry is used.
        obj_class: Class label recorded for every ``.shp`` file.

    Returns:
        The refreshed mask registry DataFrame and two identical dropdown
        updates listing every registered mask name.
        NOTE(review): the click handler in this file wires only two outputs --
        confirm the third value is intended.

    Raises:
        ValueError: If more names are given than files were uploaded.
    """
    print("Saving file(s)...")
    # Normalize each name once so the destination path, the extension check
    # and the registry row all agree. basename() keeps a user-typed name
    # such as "../x.shp" from writing outside MASK_DIR.
    names = [os.path.basename(part.strip()) for part in (filenames or "").split(";")]
    names = [name for name in names if name]
    uploads = file_objs or []
    if len(names) > len(uploads):
        raise ValueError(
            f"Got {len(names)} file name(s) but only {len(uploads)} uploaded file(s)."
        )

    for name, upload in zip(names, uploads):
        filepath = os.path.join(MASK_DIR, name)
        shutil.move(upload.name, filepath)
        # Sidecar files (.shx, .dbf, ...) are stored but not registered.
        if name.endswith(".shp"):
            add_to_info_csv(DATASET_MASK_INFO_PATH, [name, obj_class, filepath])
    print("Done!")

    dataset_df = pd.read_csv(DATASET_MASK_INFO_PATH)
    # Build choices from the freshly read registry (not a DataFrame loaded
    # when the UI was built) so newly saved masks appear in the dropdowns.
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())
    return dataset_df, update, update
|
|
|
|
def gr_save_tif_file(file_objs, filenames):
    """Copy uploaded TIFs into TIF_DIR and register them with their band count.

    Args:
        file_objs: Uploaded temp files from the Gradio ``File`` component, in
            the same order as the names in ``filenames``.
        filenames: ``;``-separated destination names; blank entries are
            skipped. Only the base name of each entry is used.

    Returns:
        The refreshed TIF registry DataFrame and two identical dropdown
        updates (for the Train and Infer TIF selectors) listing every
        registered TIF name.

    Raises:
        ValueError: If more names are given than files were uploaded.
    """
    print("Saving file(s)...")
    # Normalize each name once so the destination path and the registry row
    # agree. basename() keeps a user-typed name such as "../x.tif" from
    # writing outside TIF_DIR.
    names = [os.path.basename(part.strip()) for part in (filenames or "").split(";")]
    names = [name for name in names if name]
    uploads = file_objs or []
    if len(names) > len(uploads):
        raise ValueError(
            f"Got {len(names)} file name(s) but only {len(uploads)} uploaded file(s)."
        )

    for name, upload in zip(names, uploads):
        filepath = os.path.join(TIF_DIR, name)
        shutil.copy2(upload.name, filepath)
        n = get_tif_n_channels(filepath)
        add_to_info_csv(DATASET_TIF_INFO_PATH, [name, n, filepath])
    print("Done!")

    dataset_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    # The TIF dropdowns must list TIF names from the registry just read,
    # not mask names or a DataFrame loaded when the UI was built.
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())
    return dataset_df, update, update
|
|
|
|
def gr_generate_map(
    mask_name: str,
    token: str = "",
    show_grid=True,
    show_mask=False,
    center=(7.753769, 80.691730),
    zoom=7,
):
    """Build a Plotly map figure for a mask shapefile.

    Args:
        mask_name: Shapefile name, relative to ``MASK_DIR``.
        token: Mapbox access token. When non-empty, the "satellite-streets"
            style plus a USGS imagery raster layer is used; otherwise the
            "open-street-map" style.
        show_grid: Plot the points from ``shapefile_to_grid_indices``. They
            are also written to ``<mask>-grid.shp`` next to the mask.
        show_mask: Plot every contour returned by ``shapefile_to_latlong``.
        center: ``(lat, lon)`` the map is centred on.
        zoom: Initial map zoom level.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    mask_path = os.path.join(MASK_DIR, mask_name)
    map_center = go.layout.mapbox.Center(lat=center[0], lon=center[1])

    scattermaps = []
    if show_grid:
        indices = shapefile_to_grid_indices(mask_path)
        # Side effect: persist the grid points as a sibling shapefile.
        grid_path = os.path.splitext(mask_path)[0] + "-grid.shp"
        points_to_shapefile(indices, grid_path)
        scattermaps.append(
            go.Scattermapbox(
                lat=indices[:, 1],
                lon=indices[:, 0],
                mode="markers",
                marker=go.scattermapbox.Marker(size=6),
            )
        )
    if show_mask:
        # Column 0 holds longitudes and column 1 latitudes, as for the grid.
        for contour in shapefile_to_latlong(mask_path):
            scattermaps.append(
                go.Scattermapbox(
                    fill="toself",
                    lat=contour[:, 1],
                    lon=contour[:, 0],
                    mode="markers",
                    marker=go.scattermapbox.Marker(size=6),
                )
            )

    fig = go.Figure(scattermaps)

    if token:
        fig.update_layout(
            mapbox=dict(
                style="satellite-streets",
                accesstoken=token,
                center=map_center,
                pitch=0,
                zoom=zoom,
            ),
            mapbox_layers=[
                {
                    "sourcetype": "raster",
                    "sourceattribution": "United States Geological Survey",
                    "source": [
                        "https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}"
                    ],
                }
            ],
        )
    else:
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode="closest",
            mapbox=dict(
                bearing=0,
                center=map_center,
                pitch=0,
                zoom=zoom,
            ),
        )

    return fig
|
|
|
|
# Gradio UI: Train / Infer / Datasets / Models tabs plus the event wiring.
# Components render in creation order, so statement order here is layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """# SatSeg
Train models and run inference for segmentation of multispectral satellite images."""
    )

    # Registry snapshots used for the initial dropdown choices and tables.
    # They are read once when the UI is built; the train/save handlers re-read
    # the CSVs and return updated tables/choices.
    model_df = pd.read_csv(MODEL_INFO_PATH)
    dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)

    with gr.Tab("Train"):
        train_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        train_mask_names = gr.Dropdown(
            label="Mask files",
            choices=dataset_mask_df["Name"].to_list(),
            multiselect=True,
        )
        # gr_train_model splits this text on whitespace, so tokens in the
        # expression must be space-separated (as in the placeholder).
        train_rs_index = gr.Textbox(
            label="Remote Sensing Index", placeholder="( c0 + c1 ) / ( c0 - c1 ) ="
        )
        train_model_name = gr.Textbox(
            label="Model Name", placeholder="Give the model a name"
        )
        train_button = gr.Button("Train")

        train_completion = gr.Text(label="Training Status", value="Not Started")

    with gr.Tab("Infer"):
        infer_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        # Single-select: gr_run_inference looks up exactly one model by name.
        infer_model_name = gr.Dropdown(
            label="Model Name",
            choices=model_df["Name"].to_list(),
        )
        infer_button = gr.Button("Infer")

        infer_mask = gr.Files(label="Output Shapefile", interactive=False)

    with gr.Tab("Datasets"):
        # Re-read the same registry files as above for this tab's tables.
        dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
        dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)

        # TIF upload: files plus ";"-separated destination names, which are
        # auto-filled from the uploaded file names (see .upload wiring below).
        datasets_upload_tif = gr.File(label="Images (.tif)", file_count="multiple")
        datasets_upload_tif_name = gr.Textbox(
            label="TIF name", placeholder="tif_file_1.tif;tif_file_2.tif"
        )
        datasets_save_uploaded_tif = gr.Button("Save")

        # Mask upload: all shapefile sidecar files plus the object class label.
        datasets_upload_mask = gr.File(
            label="Masks (Please upload all extensions (.shp, .shx, etc.))",
            file_count="multiple",
        )
        datasets_upload_mask_name = gr.Textbox(
            label="Mask name", placeholder="mask_1.shp;mask_1.shx"
        )
        datasets_mask_class_name = gr.Textbox(
            label="Class (The name of the object you want to segment)"
        )
        datasets_save_uploaded_mask = gr.Button("Save")

        datasets_tif_table = gr.Dataframe(dataset_tif_df, label="TIFs")
        datasets_mask_table = gr.Dataframe(dataset_mask_df, label="Masks")

    with gr.Tab("Models"):
        models_table = gr.Dataframe(model_df)

    # Input order matches gr_train_model(tif_names, mask_names, model_name,
    # expression); outputs refresh the status, model table and model dropdown.
    train_button.click(
        gr_train_model,
        inputs=[
            train_tif_names,
            train_mask_names,
            train_model_name,
            train_rs_index,
        ],
        outputs=[train_completion, models_table, infer_model_name],
    )

    infer_button.click(
        gr_run_inference,
        inputs=[infer_tif_names, infer_model_name],
        outputs=[infer_mask],
    )

    # Pre-fill the name boxes with the uploaded files' base names, ";"-joined.
    datasets_upload_tif.upload(
        lambda y: ";".join(list(map(lambda x: os.path.basename(x.orig_name), y))),
        inputs=datasets_upload_tif,
        outputs=datasets_upload_tif_name,
    )

    datasets_upload_mask.upload(
        lambda y: ";".join(list(map(lambda x: os.path.basename(x.orig_name), y))),
        inputs=datasets_upload_mask,
        outputs=datasets_upload_mask_name,
    )

    # Saving TIFs refreshes the TIF table and both TIF dropdowns.
    datasets_save_uploaded_tif.click(
        gr_save_tif_file,
        inputs=[datasets_upload_tif, datasets_upload_tif_name],
        outputs=[datasets_tif_table, train_tif_names, infer_tif_names],
    )
    # NOTE(review): gr_save_mask_file returns three values but only two
    # outputs are wired here -- confirm this is intended.
    datasets_save_uploaded_mask.click(
        gr_save_mask_file,
        inputs=[
            datasets_upload_mask,
            datasets_upload_mask_name,
            datasets_mask_class_name,
        ],
        outputs=[datasets_mask_table, train_mask_names],
    )

# Enable the request queue (Gradio 3.x `concurrency_count` API, up to 10
# concurrent events) and start the server at import time.
demo.queue(concurrency_count=10).launch(debug=True)
|
|