histOSM / inference_tab /inference_setup.py
muk42's picture
added tiling num to annotations
6911988
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import os
# Edge length, in pixels, of the square tiles cut from the uploaded image.
TILE_SIZE = 1024
# Directory (relative to the current working directory) where tile PNGs are written.
TILE_FOLDER = "tiles"
# Created at import time so that saving tiles never fails on a missing folder.
os.makedirs(TILE_FOLDER, exist_ok=True)
# Module-level cache shared by the handlers below:
#   "tiles"         -> list of (coords, tile_path) pairs, in tile-number order
#   "selected_tile" -> {"tile_path": ..., "coords": ...} for the last picked tile, or None
tiles_cache = {"tiles": [], "selected_tile": None}
def make_tiles(image, tile_size=TILE_SIZE):
    """Split an image into square tiles and draw the numbered grid on a copy.

    Tiles are produced in row-major order (left to right, then top to
    bottom). The number drawn inside each grid cell equals the tile's index
    in the returned list.

    Args:
        image: NumPy image array whose first two axes are height and width.
            A trailing channel axis is optional, so 2-D (grayscale) arrays
            are accepted as well as H x W x C arrays.
        tile_size: Edge length of a tile in pixels.

    Returns:
        A tuple ``(annotated, tiles)``. ``annotated`` is a copy of ``image``
        with the grid rectangles and tile numbers drawn on it. ``tiles`` is
        a list of ``((x0, y0, x1, y1), tile)`` pairs, where ``tile`` is the
        corresponding slice of ``image``.
    """
    # Only the first two axes are needed; unpacking exactly three values
    # would reject 2-D input for no reason.
    height, width = image.shape[:2]
    annotated = image.copy()
    tiles = []
    for y in range(0, height, tile_size):
        for x in range(0, width, tile_size):
            x_end = x + tile_size
            y_end = y + tile_size
            # The next index in the list is this tile's number.
            tile_id = len(tiles)
            # NumPy clips the slice at the image border, so tiles in the last
            # row/column may be smaller than tile_size. The recorded box is
            # intentionally NOT clipped: select_tile derives the number of
            # columns from the last tile's right edge.
            tiles.append(((x, y, x_end, y_end), image[y:y_end, x:x_end]))
            cv2.rectangle(annotated, (x, y), (x_end, y_end), (255, 0, 0), 2)
            cv2.putText(annotated, str(tile_id), (x + 50, y + 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 5)
    return annotated, tiles
def create_tiles(image_file):
    """Tile the uploaded image, save every tile as a PNG and reset the cache.

    Args:
        image_file: Value of the image file input. Its ``name`` attribute is
            used as the path of the image to open.

    Returns:
        A tuple ``(annotated, update)``: the full image with the numbered
        tile grid drawn on it, and a Gradio update that makes the second
        output non-interactive (nothing is selected yet after re-tiling).

    Raises:
        gr.Error: If no image file has been provided.
    """
    if image_file is None:
        # Without this guard the handler fails with an opaque AttributeError
        # on ``None.name`` when the button is clicked before a file is chosen.
        raise gr.Error("Please select an image file before creating tiles.")

    # Close the source file deterministically once the pixels are in memory.
    with Image.open(image_file.name) as source:
        img = np.array(source.convert("RGB"))
    annotated, tiles = make_tiles(img, TILE_SIZE)

    tiles_cache["tiles"] = []
    for idx, (coords, tile) in enumerate(tiles):
        tile_path = os.path.join(TILE_FOLDER, f"tile_{idx}.png")
        Image.fromarray(tile).save(tile_path)
        # Cache the path rather than the pixel array.
        tiles_cache["tiles"].append((coords, tile_path))
    tiles_cache["selected_tile"] = None
    return annotated, gr.update(interactive=False)
def select_tile(evt: gr.SelectData, state):
    """Resolve a click on the annotated map to a cached tile.

    Args:
        evt: Gradio selection event; ``evt.index[0]`` and ``evt.index[1]``
            are read as the x and y position of the click.
        state: Current selected-tile state, passed through unchanged when no
            tile can be resolved.

    Returns:
        A tuple ``(tile_image, update, new_state)``. On success the tile is
        loaded from disk with its number drawn on it, the update enables the
        second output, and the state is a dict with ``"tile_path"`` and
        ``"coords"``. Otherwise the image is ``None``, the update disables
        the second output and ``state`` is returned as received.
    """
    cached = tiles_cache["tiles"]
    if not cached:
        return None, gr.update(interactive=False), state

    # The last tile's (unclipped) right edge is a multiple of TILE_SIZE, so
    # dividing it by TILE_SIZE yields the number of tile columns.
    last_box = cached[-1][0]
    columns = last_box[2] // TILE_SIZE

    row = evt.index[1] // TILE_SIZE
    col = evt.index[0] // TILE_SIZE
    tile_id = row * columns + col

    if tile_id < 0 or tile_id >= len(cached):
        return None, gr.update(interactive=False), state

    coords, tile_path = cached[tile_id]
    # Remember the path (not the pixels) both in the module cache and in the
    # state handed back to Gradio.
    tiles_cache["selected_tile"] = {"tile_path": tile_path, "coords": coords}
    updated_state = {"tile_path": tile_path, "coords": coords}

    # The pixels are loaded only to build the preview shown to the user.
    preview = np.array(Image.open(tile_path))
    cv2.putText(
        preview,
        str(tile_id),
        (100, 100),
        cv2.FONT_HERSHEY_SIMPLEX,
        2,
        (0, 0, 0),
        4,
        cv2.LINE_AA,
    )
    return preview, gr.update(interactive=True), updated_state
def enable_textbox(file):
    """Return a Gradio update that is interactive only when ``file`` is truthy."""
    has_file = bool(file)
    return gr.update(interactive=has_file)
def get_inference_widgets(run_inference,georefImg,selected_tile_state):
    """Build the inference widgets and wire their events.

    NOTE(review): presumably called inside an active ``gr.Blocks`` context;
    confirm against the caller.

    Args:
        run_inference: Callback for the "Run Inference" button. It receives
            the values of ``selected_tile_state``, the GCP file, the CRS
            textbox, the city name, the OSM score threshold, the directory
            score threshold and the historic-names CSV, in that order. Its
            outputs go to the progress textbox and the download file.
        georefImg: Callback for the "Georeference Full Map" button. It
            receives the image file, the GCP file and the CRS textbox; its
            output goes to the progress textbox.
        selected_tile_state: Component supplied by the caller that carries
            the selected tile between handlers (presumably a ``gr.State``;
            verify against the caller).

    Returns:
        The tuple ``(image_input, gcp_input, city_name, user_crs, score_th,
        hist_th, hist_dic, run_button, output, download_file)``.
    """
    with gr.Row():
        # Left column: full map with the tile grid, plus the file/CRS inputs.
        with gr.Column(scale=1,min_width=500):
            annotated_out = gr.Image(
                type="numpy", label="City Map",
                height=500, width=500
            )
            image_input = gr.File(label="Select Image File")
            gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
            city_name = gr.Textbox(label="Enter city name")
            user_crs = gr.Textbox(label="Enter CRS for the GCP",value="3395")
            create_btn = gr.Button("Create Tiles")
            georef_btn = gr.Button("Georeference Full Map")
        # Right column: selected-tile preview, thresholds and run controls.
        with gr.Column(scale=1):
            selected_tile = gr.Image(
                type="numpy", label="Selected Tile",
                height=500, width=500
            )
            score_th = gr.Textbox(label="Score threshold below which to annotate manually (OSM)",
                                  info="Computes fuzzy match of the detected street names with OSM street names within 100m buffer")
            # Historic dictionary of street names and its matching score threshold.
            hist_dic = gr.File(label="Upload csv with historic street names",file_types=[".csv"])
            hist_th = gr.Textbox(label="Score threshold below which to annotate manually (Directory)",
                                 info="Computes fuzzy match of the detected street names with the historic street names",
                                 interactive=False)
            # The directory threshold is editable only while a CSV is uploaded.
            hist_dic.change(enable_textbox, inputs=hist_dic, outputs=hist_th)
            # Starts disabled; the tile-selection handler enables it.
            run_button = gr.Button("Run Inference", interactive=False)
            output = gr.Textbox(label="Progress", lines=5, interactive=False)
            download_file = gr.File(label="Download CSV",
                                    file_types=[".csv"],
                                    type="filepath")
    # The selected-tile state is passed in by the caller instead of being
    # created here as a local gr.State().
    # Wire events
    # Tiling redraws the map and resets the run button.
    create_btn.click(
        fn=create_tiles, inputs=image_input,
        outputs=[annotated_out, run_button]
    )
    # Clicking the map selects a tile, shows it, and updates the shared state.
    annotated_out.select(
        fn=select_tile, inputs=[selected_tile_state],
        outputs=[selected_tile, run_button, selected_tile_state]
    )
    run_button.click(
        fn=run_inference,
        inputs=[selected_tile_state, gcp_input,user_crs, city_name, score_th, hist_th,hist_dic],
        outputs=[output, download_file]
    )
    georef_btn.click(
        fn=georefImg,
        inputs=[image_input, gcp_input,user_crs],
        outputs=[output]
    )
    return image_input, gcp_input, city_name, user_crs, score_th, hist_th,hist_dic, run_button, output, download_file