muk42 commited on
Commit
95327ef
·
1 Parent(s): 01f345d

added CRS conversion, plotly map instead of Folium

Browse files
app.py CHANGED
@@ -1,6 +1,8 @@
1
  # [DEBUG]
2
  from osgeo import gdal
3
 
 
 
4
  import gradio as gr
5
  import logging
6
  from inference_tab import get_inference_widgets, run_inference,georefImg
@@ -14,7 +16,7 @@ logging.basicConfig(level=logging.DEBUG)
14
 
15
  with gr.Blocks() as demo:
16
  with gr.Tab("Inference"):
17
- image_input, gcp_input, city_name, score_th, run_button, output, download_file = get_inference_widgets(run_inference,georefImg)
18
  with gr.Tab("Annotation"):
19
  get_annotation_widgets()
20
  with gr.Tab("Map"):
 
1
  # [DEBUG]
2
  from osgeo import gdal
3
 
4
+ import os
5
+ from config import OUTPUT_DIR
6
  import gradio as gr
7
  import logging
8
  from inference_tab import get_inference_widgets, run_inference,georefImg
 
16
 
17
  with gr.Blocks() as demo:
18
  with gr.Tab("Inference"):
19
+ image_input, gcp_input, city_name,user_crs, score_th, run_button, output, download_file = get_inference_widgets(run_inference,georefImg)
20
  with gr.Tab("Annotation"):
21
  get_annotation_widgets()
22
  with gr.Tab("Map"):
inference_tab/inference_logic.py CHANGED
@@ -13,7 +13,7 @@ import rasterio.features
13
  from shapely.geometry import shape
14
  import pandas as pd
15
  import osmnx as ox
16
- from osgeo import gdal
17
  import geopandas as gpd
18
  from rapidfuzz import process, fuzz
19
  from huggingface_hub import hf_hub_download
@@ -24,6 +24,7 @@ from .helpers import box_inside_global,nms_iou,non_max_suppression,tile_image_wi
24
  from pyproj import Transformer
25
  import shutil
26
  import re
 
27
 
28
  # Global cache
29
  _trocr_processor = None
@@ -471,17 +472,18 @@ def georefTile(tile_coords, gcp_path):
471
 
472
 
473
 
474
- def georefImg(image_path, gcp_path):
475
-
476
- yield "Reading GCP CSV..."
477
 
478
  TMP_FILE = os.path.join(OUTPUT_DIR,"tmp.tif")
479
  GEO_FILE = os.path.join(OUTPUT_DIR,"georeferenced.tif")
 
480
 
481
  for f in [TMP_FILE, GEO_FILE]:
482
  if os.path.exists(f):
483
  os.remove(f)
484
 
 
 
485
  df = pd.read_csv(gcp_path)
486
 
487
  H,W,_ = img_shape(image_path)
@@ -489,7 +491,7 @@ def georefImg(image_path, gcp_path):
489
 
490
  # Build GCPs
491
  gcps = []
492
- for _, r in df.iterrows():
493
  gcps.append(
494
  gdal.GCP(
495
  float(r['mapX']),
@@ -499,28 +501,64 @@ def georefImg(image_path, gcp_path):
499
  #H-float(r['sourceY'])
500
  abs(float(r['sourceY']))
501
  )
502
- )
503
 
504
- gdal.Translate(
 
 
 
 
 
 
 
 
 
 
 
505
  TMP_FILE,
506
  image_path,
507
  format="GTiff",
508
  GCPs=gcps,
509
  outputSRS="EPSG:3857"
510
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
 
512
  gdal.Warp(
513
  GEO_FILE,
514
- TMP_FILE,
515
  dstSRS="EPSG:3857",
516
  resampleAlg="near",
517
  polynomialOrder=1,
518
- creationOptions=["COMPRESS=LZW"]
 
519
  )
520
 
521
 
522
 
523
- yield "Georeferencing is done."
524
 
525
 
526
  def extractStreetNet(city_name):
 
13
  from shapely.geometry import shape
14
  import pandas as pd
15
  import osmnx as ox
16
+ from osgeo import gdal, osr
17
  import geopandas as gpd
18
  from rapidfuzz import process, fuzz
19
  from huggingface_hub import hf_hub_download
 
24
  from pyproj import Transformer
25
  import shutil
26
  import re
27
+ from pyproj import Transformer
28
 
29
  # Global cache
30
  _trocr_processor = None
 
472
 
473
 
474
 
475
+ def georefImg(image_path, gcp_path, user_crs):
 
 
476
 
477
  TMP_FILE = os.path.join(OUTPUT_DIR,"tmp.tif")
478
  GEO_FILE = os.path.join(OUTPUT_DIR,"georeferenced.tif")
479
+ VRT_FILE = os.path.join(OUTPUT_DIR,"vrt_file.vrt")
480
 
481
  for f in [TMP_FILE, GEO_FILE]:
482
  if os.path.exists(f):
483
  os.remove(f)
484
 
485
+ yield "Read GCP points..."
486
+
487
  df = pd.read_csv(gcp_path)
488
 
489
  H,W,_ = img_shape(image_path)
 
491
 
492
  # Build GCPs
493
  gcps = []
494
+ '''for _, r in df.iterrows():
495
  gcps.append(
496
  gdal.GCP(
497
  float(r['mapX']),
 
501
  #H-float(r['sourceY'])
502
  abs(float(r['sourceY']))
503
  )
504
+ )'''
505
 
506
+ for _, r in df.iterrows():
507
+ gcps.append((
508
+ float(r['mapX']),
509
+ float(r['mapY']),
510
+ float(r['sourceX']),
511
+ #H-float(r['sourceY'])
512
+ abs(float(r['sourceY']))
513
+ ))
514
+
515
+
516
+ # OLD
517
+ '''gdal.Translate(
518
  TMP_FILE,
519
  image_path,
520
  format="GTiff",
521
  GCPs=gcps,
522
  outputSRS="EPSG:3857"
523
+ )'''
524
+
525
+ yield "Transform GCP to user specified CRS..."
526
+
527
+ # Transform GCP from user provided CRS to Web Mercator 3857
528
+ transformer=Transformer.from_crs(f"epsg:{user_crs}","epsg:3857",always_xy=True)
529
+ gcps3857=[]
530
+ for px,py,x,y in gcps:
531
+ x3857,y3857=transformer.transform(px,py)
532
+ gcp=gdal.GCP(x3857,y3857,0,x,y)
533
+ gcps3857.append(gcp)
534
+
535
+ yield "Apply GCP to the image..."
536
+
537
+ # Apply GCP to the image
538
+ src_ds=gdal.Open(image_path)
539
+ drv=gdal.GetDriverByName('VRT')
540
+ vrt_ds=drv.CreateCopy(VRT_FILE,src_ds,0)
541
+
542
+ # Set the GCPs and spatial reference system
543
+ srs3857=osr.SpatialReference()
544
+ srs3857.ImportFromEPSG(3857)
545
+ vrt_ds.SetGCPs(gcps3857,srs3857.ExportToWkt())
546
+ vrt_ds=None # close vrt to save changes
547
+
548
 
549
  gdal.Warp(
550
  GEO_FILE,
551
+ VRT_FILE, # TMP_FILE,
552
  dstSRS="EPSG:3857",
553
  resampleAlg="near",
554
  polynomialOrder=1,
555
+ creationOptions=["COMPRESS=LZW"],
556
+ format='GTiff'
557
  )
558
 
559
 
560
 
561
+ yield "The map is georeferenced."
562
 
563
 
564
  def extractStreetNet(city_name):
inference_tab/inference_setup.py CHANGED
@@ -84,9 +84,13 @@ def get_inference_widgets(run_inference,georefImg):
84
  type="numpy", label="City Map",
85
  height=500, width=500
86
  )
87
- city_name = gr.Textbox(label="Enter city name")
88
  image_input = gr.File(label="Select Image File")
89
  gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
 
 
 
 
90
  create_btn = gr.Button("Create Tiles")
91
  georef_btn = gr.Button("Georeference Full Map")
92
 
@@ -124,9 +128,9 @@ def get_inference_widgets(run_inference,georefImg):
124
 
125
  georef_btn.click(
126
  fn=georefImg,
127
- inputs=[image_input, gcp_input],
128
  outputs=[output]
129
  )
130
 
131
 
132
- return image_input, gcp_input, city_name, score_th, run_button, output, download_file
 
84
  type="numpy", label="City Map",
85
  height=500, width=500
86
  )
87
+
88
  image_input = gr.File(label="Select Image File")
89
  gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
90
+
91
+ city_name = gr.Textbox(label="Enter city name")
92
+ user_crs = gr.Textbox(label="Enter CRS for the GCP",value="3395")
93
+
94
  create_btn = gr.Button("Create Tiles")
95
  georef_btn = gr.Button("Georeference Full Map")
96
 
 
128
 
129
  georef_btn.click(
130
  fn=georefImg,
131
+ inputs=[image_input, gcp_input,user_crs],
132
  outputs=[output]
133
  )
134
 
135
 
136
+ return image_input, gcp_input, city_name, user_crs, score_th, run_button, output, download_file
map_tab/map_logic.py CHANGED
@@ -1,94 +1,120 @@
1
  import os
2
- import folium
3
- from folium.raster_layers import ImageOverlay
4
- from geopy.geocoders import Nominatim
5
- import rasterio
6
  import numpy as np
7
- from matplotlib import cm, colors
8
  import pandas as pd
 
9
  import pyproj
10
  import matplotlib.pyplot as plt
11
- from branca.colormap import linear
 
12
  from config import OUTPUT_DIR
 
 
13
 
14
  CELL_SIZE_M = 100 # meters
15
 
16
 
17
- def export_georeferenced_png(raster_path, png_path):
18
- """Export the raster to a georeferenced PNG that aligns with Folium ImageOverlay."""
19
- with rasterio.open(raster_path) as src:
20
- arr = src.read()
21
- if arr.shape[0] >= 3:
22
- img = arr[:3].transpose(1, 2, 0) # (H, W, RGB)
23
- else:
24
- img = arr[0]
25
- bounds = src.bounds
26
-
27
- plt.imshow(img, extent=[bounds.left, bounds.right, bounds.bottom, bounds.top])
28
- plt.axis("off")
29
- plt.savefig(png_path, bbox_inches="tight", pad_inches=0, transparent=True)
30
- plt.close()
31
-
32
-
33
  def make_map(city, show_grid, show_georef):
34
  city = city.strip()
35
  if not city:
36
- return "Please enter a city"
37
 
 
38
  geolocator = Nominatim(
39
  user_agent="histOSM_gradioAPP (maria.u.kuznetsova@gmail.com)",
40
  timeout=10
41
  )
42
  loc = geolocator.geocode(city)
43
  if loc is None:
44
- return f"Could not find '{city}'"
45
-
46
- m = folium.Map(location=[loc.latitude, loc.longitude], zoom_start=12)
47
-
 
 
 
 
 
 
 
 
 
 
 
 
48
  raster_path = os.path.join(OUTPUT_DIR, "georeferenced.tif")
49
  if not os.path.exists(raster_path):
50
- return "Georeferenced raster not found"
51
 
52
  with rasterio.open(raster_path) as src:
53
  bounds = src.bounds
54
  crs = src.crs
55
-
56
  xmin, ymin, xmax, ymax = bounds
57
  transformer = pyproj.Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
58
-
59
- # Convert raster bounds to lat/lon
60
  lon0, lat0 = transformer.transform(xmin, ymin)
61
  lon1, lat1 = transformer.transform(xmax, ymax)
62
  lat_min, lat_max = sorted([lat0, lat1])
63
  lon_min, lon_max = sorted([lon0, lon1])
64
 
65
- # Overlay raster if requested
66
  if show_georef:
67
- raster_img_path = os.path.join(OUTPUT_DIR, "georeferenced_rgba.png")
68
- if not os.path.exists(raster_img_path):
69
- export_georeferenced_png(raster_path, raster_img_path)
70
-
71
- ImageOverlay(
72
- image=raster_img_path,
73
- bounds=[[lat_min, lon_min], [lat_max, lon_max]],
74
- opacity=0.85,
75
- interactive=True,
76
- ).add_to(m)
77
-
78
- # Debug markers
79
- folium.Marker([loc.latitude, loc.longitude], tooltip="City center").add_to(m)
80
- cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
81
- clon, clat = transformer.transform(cx, cy)
82
- folium.Marker([clat, clon], tooltip="Raster center").add_to(m)
83
-
84
- # Grid overlay
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  if show_grid:
86
- _add_grid_overlay(m, transformer)
87
-
88
- return m._repr_html_()
 
 
 
 
 
 
 
 
89
 
90
 
91
- def _add_grid_overlay(m, transformer):
 
92
  grid_values = []
93
 
94
  for fname in os.listdir(OUTPUT_DIR):
@@ -122,12 +148,9 @@ def _add_grid_overlay(m, transformer):
122
  all_scores = np.concatenate([g[0].flatten() for g in grid_values])
123
  min_val, max_val = all_scores.min(), all_scores.max()
124
  if min_val == max_val:
125
- max_val = min_val + 1e-6
126
 
127
- cmap = cm.get_cmap("Reds")
128
- colormap = linear.Reds_09.scale(min_val, max_val)
129
- colormap.caption = "Average OSM Match Score"
130
- colormap.add_to(m)
131
 
132
  for grid, tile_xmin, tile_ymin, n_rows, n_cols in grid_values:
133
  for r in range(n_rows):
@@ -143,14 +166,15 @@ def _add_grid_overlay(m, transformer):
143
  y1 = y0 + CELL_SIZE_M
144
  lon0, lat0 = transformer.transform(x0, y0)
145
  lon1, lat1 = transformer.transform(x1, y1)
146
- lat_min, lat_max = sorted([lat0, lat1])
147
- lon_min, lon_max = sorted([lon0, lon1])
148
- folium.Rectangle(
149
- bounds=[[lat_min, lon_min], [lat_max, lon_max]],
150
- color=None,
151
- weight=0,
152
- fill=True,
153
- fill_color=color,
154
- fill_opacity=0.7,
155
- popup=f"{val:.2f}",
156
- ).add_to(m)
 
 
1
  import os
2
+ import io
3
+ import base64
 
 
4
  import numpy as np
 
5
  import pandas as pd
6
+ import rasterio
7
  import pyproj
8
  import matplotlib.pyplot as plt
9
+ from matplotlib import cm, colors
10
+ from geopy.geocoders import Nominatim
11
  from config import OUTPUT_DIR
12
+ import plotly.graph_objects as go
13
+ from PIL import Image
14
 
15
  CELL_SIZE_M = 100 # meters
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def make_map(city, show_grid, show_georef):
19
  city = city.strip()
20
  if not city:
21
+ return go.Figure().add_annotation(text="Please enter a city")
22
 
23
+ # --- Geocode city ---
24
  geolocator = Nominatim(
25
  user_agent="histOSM_gradioAPP (maria.u.kuznetsova@gmail.com)",
26
  timeout=10
27
  )
28
  loc = geolocator.geocode(city)
29
  if loc is None:
30
+ return go.Figure().add_annotation(text=f"Could not find '{city}'")
31
+
32
+ lat_center, lon_center = loc.latitude, loc.longitude
33
+ fig = go.Figure()
34
+
35
+ # --- Add city marker ---
36
+ fig.add_trace(go.Scattermapbox(
37
+ lat=[lat_center],
38
+ lon=[lon_center],
39
+ mode="markers+text",
40
+ text=["City center"],
41
+ textposition="top right",
42
+ marker=dict(size=12, color="blue")
43
+ ))
44
+
45
+ # --- Load raster ---
46
  raster_path = os.path.join(OUTPUT_DIR, "georeferenced.tif")
47
  if not os.path.exists(raster_path):
48
+ return go.Figure().add_annotation(text="Georeferenced raster not found")
49
 
50
  with rasterio.open(raster_path) as src:
51
  bounds = src.bounds
52
  crs = src.crs
 
53
  xmin, ymin, xmax, ymax = bounds
54
  transformer = pyproj.Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
 
 
55
  lon0, lat0 = transformer.transform(xmin, ymin)
56
  lon1, lat1 = transformer.transform(xmax, ymax)
57
  lat_min, lat_max = sorted([lat0, lat1])
58
  lon_min, lon_max = sorted([lon0, lon1])
59
 
60
+ # --- Optional raster overlay ---
61
  if show_georef:
62
+ with rasterio.open(raster_path) as src:
63
+ arr = src.read(out_dtype="uint8")
64
+ if arr.shape[0] >= 3:
65
+ img = arr[:3].transpose(1, 2, 0)
66
+ else:
67
+ img = arr[0]
68
+
69
+ img = np.clip(img, 0, 255)
70
+ image = Image.fromarray(img)
71
+ buffer = io.BytesIO()
72
+ image.save(buffer, format="PNG")
73
+ encoded = base64.b64encode(buffer.getvalue()).decode()
74
+
75
+ fig.update_layout(mapbox_layers=[
76
+ dict(
77
+ sourcetype="image",
78
+ source="data:image/png;base64," + encoded,
79
+ coordinates=[
80
+ [lon_min, lat_max],
81
+ [lon_max, lat_max],
82
+ [lon_max, lat_min],
83
+ [lon_min, lat_min]
84
+ ],
85
+ opacity=0.65
86
+ )
87
+ ])
88
+
89
+ # Add raster center marker
90
+ cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
91
+ clon, clat = transformer.transform(cx, cy)
92
+ fig.add_trace(go.Scattermapbox(
93
+ lat=[clat],
94
+ lon=[clon],
95
+ mode="markers+text",
96
+ text=["Raster center"],
97
+ textposition="bottom right",
98
+ marker=dict(size=10, color="red")
99
+ ))
100
+
101
+ # --- Optional grid overlay ---
102
  if show_grid:
103
+ _add_grid_overlay_plotly(fig, transformer)
104
+
105
+ # --- Layout ---
106
+ fig.update_layout(
107
+ mapbox_style="open-street-map",
108
+ mapbox_zoom=12,
109
+ mapbox_center={"lat": lat_center, "lon": lon_center},
110
+ margin={"l": 0, "r": 0, "t": 0, "b": 0},
111
+ showlegend=False
112
+ )
113
+ return fig
114
 
115
 
116
+ def _add_grid_overlay_plotly(fig, transformer):
117
+ """Add grid overlay as filled polygons on the Plotly map."""
118
  grid_values = []
119
 
120
  for fname in os.listdir(OUTPUT_DIR):
 
148
  all_scores = np.concatenate([g[0].flatten() for g in grid_values])
149
  min_val, max_val = all_scores.min(), all_scores.max()
150
  if min_val == max_val:
151
+ max_val += 1e-6
152
 
153
+ cmap = plt.get_cmap("Reds")
 
 
 
154
 
155
  for grid, tile_xmin, tile_ymin, n_rows, n_cols in grid_values:
156
  for r in range(n_rows):
 
166
  y1 = y0 + CELL_SIZE_M
167
  lon0, lat0 = transformer.transform(x0, y0)
168
  lon1, lat1 = transformer.transform(x1, y1)
169
+ lons = [lon0, lon1, lon1, lon0, lon0]
170
+ lats = [lat0, lat0, lat1, lat1, lat0]
171
+ fig.add_trace(go.Scattermapbox(
172
+ lon=lons,
173
+ lat=lats,
174
+ mode="lines",
175
+ fill="toself",
176
+ fillcolor=color,
177
+ line=dict(width=0),
178
+ hoverinfo="text",
179
+ text=f"{val:.2f}"
180
+ ))
map_tab/map_setup.py CHANGED
@@ -3,7 +3,8 @@ from .map_logic import make_map
3
 
4
 
5
  def get_map_widgets(city_component):
6
- map_output = gr.HTML(value="Map will appear here once you type a city", elem_id="map-widget")
 
7
  show_grid = gr.Checkbox(label="Draw Grid", value=False)
8
  show_georef = gr.Checkbox(label="Show Georeferenced Map", value=False)
9
 
 
3
 
4
 
5
  def get_map_widgets(city_component):
6
+ #map_output = gr.HTML(value="Map will appear here once you type a city", elem_id="map-widget")
7
+ map_output = gr.Plot(label="Map", value=None)
8
  show_grid = gr.Checkbox(label="Draw Grid", value=False)
9
  show_georef = gr.Checkbox(label="Show Georeferenced Map", value=False)
10