muk42 commited on
Commit
c43f7d1
·
1 Parent(s): bf77b4b
inference_tab/inference_logic.py CHANGED
@@ -1,130 +1,633 @@
1
- import gradio as gr
2
- import cv2
3
  import numpy as np
4
- from PIL import Image
5
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- TILE_SIZE = 1024
8
- TILE_FOLDER = "tiles"
9
- os.makedirs(TILE_FOLDER, exist_ok=True)
10
- tiles_cache = {"tiles": [], "selected_tile": None}
11
 
 
 
 
 
 
 
 
12
 
13
- def make_tiles(image, tile_size=TILE_SIZE):
14
- h, w, _ = image.shape
15
- annotated = image.copy()
16
- tiles = []
17
- tile_id = 0
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- for y in range(0, h, tile_size):
20
- for x in range(0, w, tile_size):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  tile = image[y:y+tile_size, x:x+tile_size]
22
- tiles.append(((x, y, x+tile_size, y+tile_size), tile))
23
- cv2.rectangle(annotated, (x, y), (x+tile_size, y+tile_size), (255,0,0), 2)
24
- cv2.putText(annotated, str(tile_id), (x+50, y+50),
25
- cv2.FONT_HERSHEY_SIMPLEX, 2, (0,0,0), 5)
26
- tile_id += 1
27
- return annotated, tiles
28
-
29
- def create_tiles(image_file):
30
- img = Image.open(image_file.name).convert("RGB")
31
- img = np.array(img)
32
-
33
- annotated, tiles = make_tiles(img, TILE_SIZE)
34
- tiles_cache["tiles"] = []
35
-
36
- for idx, (coords, tile) in enumerate(tiles):
37
- tile_path = os.path.join(TILE_FOLDER, f"tile_{idx}.png")
38
- Image.fromarray(tile).save(tile_path)
39
- tiles_cache["tiles"].append((coords, tile_path)) # store path instead of array
40
-
41
- tiles_cache["selected_tile"] = None
42
- return annotated, gr.update(interactive=False)
43
-
44
- def select_tile(evt: gr.SelectData,state):
45
- # compute tile index
46
- if not tiles_cache["tiles"]:
47
- return None, gr.update(interactive=False), state
48
-
49
- num_tiles_x = (tiles_cache["tiles"][-1][0][2]) // TILE_SIZE
50
- tile_id = (evt.index[1] // TILE_SIZE) * num_tiles_x + (evt.index[0] // TILE_SIZE)
51
-
52
- if 0 <= tile_id < len(tiles_cache["tiles"]):
53
- coords, tile_path = tiles_cache["tiles"][tile_id]
54
-
55
- # store the path, not the array
56
- tiles_cache["selected_tile"] = {
57
- "tile_path": tile_path,
58
- "coords": coords
59
- }
60
 
61
- updated_state = {
62
- "tile_path": tile_path,
63
- "coords": coords
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- # load tile only for display
67
  tile_array = np.array(Image.open(tile_path))
68
- cv2.putText(tile_array, str(tile_id), (100, 100),
69
- cv2.FONT_HERSHEY_SIMPLEX, 2, (0,0,0), 4, cv2.LINE_AA)
70
 
71
- return tile_array, gr.update(interactive=True),updated_state
 
72
 
73
- return None, gr.update(interactive=False), state
 
 
 
 
74
 
75
 
 
 
 
 
 
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
 
 
78
 
79
- def get_inference_widgets(run_inference,georefImg):
80
- with gr.Row():
81
- # Left column
82
- with gr.Column(scale=1,min_width=500):
83
- annotated_out = gr.Image(
84
- type="numpy", label="City Map",
85
- height=500, width=500
86
- )
87
- city_name = gr.Textbox(label="Enter city name")
88
- image_input = gr.File(label="Select Image File")
89
- gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
90
- create_btn = gr.Button("Create Tiles")
91
- georef_btn = gr.Button("Georeference Full Map")
92
-
93
-
94
- # Right column
95
- with gr.Column(scale=1):
96
- selected_tile = gr.Image(
97
- type="numpy", label="Selected Tile",
98
- height=500, width=500
99
- )
100
- score_th = gr.Textbox(label="Enter a score threshold below which to annotate manually")
101
- run_button = gr.Button("Run Inference", interactive=False)
102
- output = gr.Textbox(label="Progress", lines=5, interactive=False)
103
- download_file = gr.File(label="Download CSV")
104
 
105
- selected_tile_path = gr.State()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
- # Wire events
109
- create_btn.click(
110
- fn=create_tiles, inputs=image_input,
111
- outputs=[annotated_out, run_button]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  )
113
- annotated_out.select(
114
- fn=select_tile, inputs=[selected_tile_path],
115
- outputs=[selected_tile, run_button, selected_tile_path]
 
 
 
 
 
116
  )
117
- run_button.click(
118
- fn=run_inference,
119
- inputs=[selected_tile_path, gcp_input, city_name, score_th],
120
- outputs=[output, download_file]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  )
122
 
123
- georef_btn.click(
124
- fn=georefImg,
125
- inputs=[image_input, gcp_input],
126
- outputs=[output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- return image_input, gcp_input, city_name, score_th, run_button, output, download_file
 
1
+ import spaces
 
2
  import numpy as np
3
+ from ultralytics import YOLO
4
  import os
5
+ import json
6
+ from PIL import Image
7
+ from ultralytics import SAM
8
+ import cv2
9
+ import torch
10
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
11
+ import rasterio
12
+ import rasterio.features
13
+ from shapely.geometry import shape
14
+ import pandas as pd
15
+ import osmnx as ox
16
+ from osgeo import gdal
17
+ import geopandas as gpd
18
+ from rapidfuzz import process, fuzz
19
+ from huggingface_hub import hf_hub_download
20
+ from config import OUTPUT_DIR
21
+ from pathlib import Path
22
+ from PIL import Image
23
+ from .helpers import box_inside_global,nms_iou,non_max_suppression,tile_image_with_overlap,compute_iou,merge_boxes,box_area,is_contained,merge_boxes_iterative,get_corner_points,sample_negative_points_outside_boxes,get_inset_corner_points,processYOLOBoxes,prepare_tiles,merge_tile_masks,chunkify,img_shape,best_street_match
24
+ from pyproj import Transformer
25
+ import shutil
26
+
27
+ # Global cache
28
+ _trocr_processor = None
29
+ _trocr_model = None
30
+ _trocr_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+
33
+
34
+
35
+ def run_inference(tile_dict, gcp_path, city_name, score_th):
36
+ IMAGE_FOLDER = os.path.join(OUTPUT_DIR, "blobs")
37
+ CSV_FILE = os.path.join(OUTPUT_DIR, "annotations.csv")
38
+ MASK_FILE = os.path.join(OUTPUT_DIR, "mask.tif")
39
+
40
+
41
+ if os.path.exists(IMAGE_FOLDER):
42
+ shutil.rmtree(IMAGE_FOLDER)
43
+ os.makedirs(IMAGE_FOLDER, exist_ok=True)
44
+
45
+ if os.path.exists("tmp"):
46
+ shutil.rmtree("tmp")
47
+ os.makedirs("tmp", exist_ok=True)
48
+
49
+
50
+
51
+ if os.path.exists(CSV_FILE):
52
+ os.remove(CSV_FILE)
53
+ if os.path.exists(MASK_FILE):
54
+ os.remove(MASK_FILE)
55
 
 
 
 
 
56
 
57
+ log = ""
58
+ if tile_dict is None:
59
+ yield "No tile selected", None
60
+ return
61
+
62
+ image_path = tile_dict["tile_path"]
63
+ coords = tile_dict["coords"] # (x_start, y_start, x_end, y_end)
64
 
65
+ print(f"Tile path: {image_path}; Tile coords: {coords}")
66
+
67
+ # ==== TEXT DETECTION ====
68
+ for msg in getBBoxes(image_path):
69
+ log += msg + "\n"
70
+ yield log, None
71
+ for msg in getSegments(image_path):
72
+ if msg.endswith(".tif"):
73
+ log += f"Mask saved at {msg}.\n"
74
+ yield log, msg
75
+ else:
76
+ log += msg + "\n"
77
+ yield log, None
78
+ for msg in extractSegments(image_path):
79
+ log += msg + "\n"
80
+ yield log, None
81
 
82
+ # === TEXT RECOGNITION ===
83
+ for msg in blobsOCR_all():
84
+ log += msg + "\n"
85
+ yield log, None
86
+
87
+ # === ADD GEO DATA ===
88
+
89
+ for msg in georefTile(coords,gcp_path):
90
+ log += msg + "\n"
91
+ yield log, None
92
+ '''for msg in georefImg(MASK_PATH, gcp_path):
93
+ log += msg + "\n"
94
+ yield log, None'''
95
+ for msg in extractCentroids(image_path):
96
+ log += msg + "\n"
97
+ yield log, None
98
+ for msg in extractStreetNet(city_name):
99
+ log += msg + "\n"
100
+ yield log, None
101
+
102
+ # === POST OCR ===
103
+ for msg in fuzzyMatch(score_th):
104
+ if msg.endswith(".csv"):
105
+ log+= f"Finished! CSV saved at {msg}. Street labels are ready for manual input."
106
+ yield log, msg
107
+ else:
108
+ log += msg + "\n"
109
+ yield log, None
110
+
111
+
112
+
113
+ def load_trocr_model():
114
+ """Load TrOCR into GPU if not cached."""
115
+ global _trocr_processor, _trocr_model
116
+ if _trocr_model is None:
117
+ _trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
118
+ _trocr_model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
119
+ _trocr_model.to(_trocr_device).eval()
120
+ return _trocr_processor, _trocr_model
121
+
122
+ @spaces.GPU
123
+ def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
124
+ yield f"DEBUG: Received image_path: {image_path}"
125
+ image = cv2.imread(image_path)
126
+ H, W, _ = image.shape
127
+
128
+ yolo_weights = hf_hub_download(
129
+ repo_id="muk42/yolov9_streets",
130
+ filename="yolov9c_finetuned.pt")
131
+
132
+ model = YOLO(yolo_weights)
133
+
134
+
135
+ step = int(tile_size * (1 - overlap))
136
+ all_detections=[]
137
+
138
+ total_tiles = 0
139
+ # Calculate total tiles for progress reporting
140
+ for y in range(0, H, step):
141
+ for x in range(0, W, step):
142
+ # Skip small tiles at the edges
143
+ if y + tile_size > H or x + tile_size > W:
144
+ continue
145
+ total_tiles += 1
146
+
147
+ processed_tiles = 0
148
+
149
+ # Tile the image and run prediction
150
+ for y in range(0, H, step):
151
+ for x in range(0, W, step):
152
  tile = image[y:y+tile_size, x:x+tile_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ if tile.shape[0] < tile_size or tile.shape[1] < tile_size:
155
+ continue
156
+
157
+ results= model.predict(source=tile, imgsz=tile_size, conf=confidence_threshold, verbose=False, iou=0.5)
158
+
159
+ for result in results:
160
+ boxes = result.boxes.xyxy.cpu().numpy()
161
+ scores = result.boxes.conf.cpu().numpy()
162
+ classes = result.boxes.cls.cpu().numpy()
163
+
164
+ for box, score, cls in zip(boxes, scores, classes):
165
+ x1, y1, x2, y2 = box
166
+ # Shift box coordinates relative to full image
167
+ x1 += x
168
+ x2 += x
169
+ y1 += y
170
+ y2 += y
171
+ all_detections.append([x1, y1, x2, y2, float(score), int(cls)])
172
+
173
+ processed_tiles += 1
174
+ yield f"Processed tile {processed_tiles} of {total_tiles}"
175
+
176
+ # After all tiles are processed, save detections to JSON
177
+ boxes_to_save = [
178
+ {
179
+ "bbox": [float(x1), float(y1), float(x2), float(y2)],
180
+ "score": float(conf),
181
+ "class": int(cls)
182
  }
183
+ for x1, y1, x2, y2, conf, cls in all_detections
184
+ ]
185
+
186
+
187
+ BOXES_PATH = os.path.join(OUTPUT_DIR,"boxes.json")
188
+
189
+ with open(BOXES_PATH, "w") as f:
190
+ json.dump(boxes_to_save, f, indent=4)
191
+
192
+ yield f"Inference complete."
193
+
194
+
195
+
196
+ @spaces.GPU
197
+ def run_tile_inference():
198
+ model = SAM("mobile_sam.pt") # sam2.1_l.pt
199
+ Path("tmp/masks").mkdir(parents=True, exist_ok=True)
200
+ with open("tmp/tiles_meta.json", "r") as f:
201
+ tiles_meta = json.load(f)
202
+ for tile in tiles_meta:
203
+ yield f"Processing {tile['idx']}..."
204
+ tile_path = f"tmp/tiles/tile_{tile['idx']}.png"
205
+ out_path = tile_path.replace("tiles", "masks").replace(".png", ".npy")
206
+
207
+ # skip if already processed
208
+ if Path(out_path).exists():
209
+ continue
210
+
211
+
212
+ local_boxes = tile.get('local_boxes', [])
213
+ point_coords = tile.get('point_coords', [])
214
+ point_labels = tile.get('point_labels', [])
215
 
 
216
  tile_array = np.array(Image.open(tile_path))
 
 
217
 
218
+ results = model(tile_array, bboxes=local_boxes,
219
+ points=point_coords, labels=point_labels)
220
 
221
+
222
+ masks_to_save = [r.masks.data.cpu().numpy() for r in results if r.masks is not None]
223
+ if masks_to_save:
224
+ masks_stack = np.concatenate(masks_to_save, axis=0) # shape (N, H, W)
225
+ np.save(out_path, masks_stack)
226
 
227
 
228
+ def getSegments(image_path,iou=0.5,c_th=0.75,edge_margin=10):
229
+ """
230
+ iou for combining bounding boxes
231
+ c_th defined share of the smaller box contained in the larger box for merge
232
+ edge_margin pixel margin for tiles
233
+ """
234
 
235
+ yield "Load YOLO boxes.."
236
+ BOXES_PATH = os.path.join(OUTPUT_DIR,"boxes.json")
237
+ with open(BOXES_PATH, "r") as f:
238
+ box_data = json.load(f)
239
+ boxes = [b["bbox"] for b in box_data]
240
+ yield "Prepare tiles..."
241
+ H,W = prepare_tiles(image_path, boxes, tile_size=1024, overlap=50, iou=iou, c_th=c_th, edge_margin=edge_margin)
242
+ yield "Run inference on tiles..."
243
+ for msg in run_tile_inference():
244
+ yield msg
245
+ yield "Marge predicted masks into image..."
246
+ merge_tile_masks(H,W)
247
 
248
+ MASK_PATH = os.path.join(OUTPUT_DIR,"mask.tif")
249
+ yield f"{MASK_PATH}"
250
+
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
+ def extractSegments(image_path, min_size=500, margin=100):
254
+
255
+ image = cv2.imread(image_path)
256
+ MASK_PATH = os.path.join(OUTPUT_DIR,"mask.tif")
257
+ mask = cv2.imread(MASK_PATH, cv2.IMREAD_UNCHANGED)
258
+
259
+ height, width = mask.shape[:2]
260
+
261
+ # Get unique labels (excluding background label 0)
262
+ blob_ids = np.unique(mask)
263
+ blob_ids = blob_ids[blob_ids != 0]
264
+
265
+ yield f"Found {len(blob_ids)} blobs"
266
+
267
+ for blob_id in blob_ids:
268
+ yield f"Processing blob {blob_id}..."
269
+ # Create a binary mask for the current blob
270
+ blob_mask = (mask == blob_id).astype(np.uint8)
271
+
272
+ # Skip small blobs (WxH)
273
+ if np.sum(blob_mask) < min_size:
274
+ continue
275
+
276
+ # Find bounding box of the blob
277
+ ys, xs = np.where(blob_mask)
278
+ y_min, y_max = ys.min(), ys.max() + 1
279
+ x_min, x_max = xs.min(), xs.max() + 1
280
+
281
+ # Add margin to bounding box while keeping inside image bounds
282
+ x_min = max(0, x_min - margin)
283
+ y_min = max(0, y_min - margin)
284
+ x_max = min(width, x_max + margin)
285
+ y_max = min(height, y_max + margin)
286
+
287
+ # Crop the region from original image
288
+ cropped_image = image[y_min:y_max, x_min:x_max]
289
+ cropped_mask = blob_mask[y_min:y_max, x_min:x_max]
290
+
291
+ # Apply mask to original image
292
+ shaded = cropped_image.copy()
293
+ overlay = cropped_image.copy()
294
+ overlay[cropped_mask == 1] = (255, 200, 100)
295
+ alpha = 0.35
296
+ shaded = cv2.addWeighted(overlay, alpha, shaded, 1 - alpha, 0)
297
+
298
+ # Save the masked image
299
+ BLOB_PATH=os.path.join(OUTPUT_DIR,"blobs",f"{blob_id}.png")
300
+ cv2.imwrite(BLOB_PATH, shaded)
301
+
302
+ yield f"Done."
303
+
304
+ '''@spaces.GPU(duration=180)
305
+ def blobsOCR(image_path):
306
+ yield "Load OCR model.."
307
+ # Load model + processor
308
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
309
+ model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
310
+ image_extensions = (".png")
311
+ # Device setup
312
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
313
+ model.half().to(device) # float16 weights precision
314
+ yield f"Running on {device}..."
315
+ # Open output file for writing
316
+ OCR_PATH = os.path.join(OUTPUT_DIR,"ocr.csv")
317
+ with open(OCR_PATH, "w", encoding="utf-8") as f_out:
318
+ # Process each image
319
+ image_folder = os.path.join(OUTPUT_DIR,"blobs")
320
+ for filename in os.listdir(image_folder):
321
+ if filename.lower().endswith(image_extensions):
322
+ image_path = os.path.join(image_folder, filename)
323
+
324
+ try:
325
+ image = Image.open(image_path).convert("RGB")
326
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
327
+
328
+ generated_ids = model.generate(pixel_values)
329
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
330
+
331
+
332
+ # Write to file
333
+ name = os.path.splitext(os.path.basename(filename))[0]
334
+ f_out.write(f'{name},"{generated_text}"\n')
335
+ yield f"{filename} → {generated_text}"
336
+
337
+ except Exception as e:
338
+ yield f"Error processing {filename}: {e}"'''
339
+
340
+ @spaces.GPU
341
+ def blobsOCR_chunk(image_paths):
342
+ """Run OCR on a list of images (one chunk)."""
343
+ processor, model = load_trocr_model()
344
+ results = []
345
+
346
+ # Load all images in the chunk
347
+ images = [Image.open(path).convert("RGB") for path in image_paths]
348
+
349
+ # Convert to pixel_values tensor
350
+ pixel_values = processor(images=images, return_tensors="pt", padding=True).pixel_values.to(_trocr_device)
351
+
352
+ # Generate text for the whole batch at once
353
+ generated_ids = model.generate(pixel_values)
354
+ texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
355
+
356
+ for path, text in zip(image_paths, texts):
357
+ name = os.path.splitext(os.path.basename(path))[0]
358
+ results.append((name, text))
359
+
360
+ return results
361
+
362
+ def blobsOCR_all():
363
+ image_folder = os.path.join(OUTPUT_DIR, "blobs")
364
+ all_files = [os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith(".png")]
365
+
366
+ OCR_PATH = os.path.join(OUTPUT_DIR,"ocr.csv")
367
+ with open(OCR_PATH, "w", encoding="utf-8") as f_out:
368
+ for chunk in chunkify(all_files, n=16): # adjust batch size
369
+ yield f"Processing {len(chunk)} images..."
370
+ results = blobsOCR_chunk(chunk)
371
+ for name, text in results:
372
+ f_out.write(f'{name},"{text}"\n')
373
+ yield f"{name} → {text}"
374
 
375
 
376
+ def extractCentroids(image_path):
377
+ GEO_PATH=os.path.join(OUTPUT_DIR,"mask_georef.tif")
378
+ with rasterio.open(GEO_PATH) as src:
379
+ mask = src.read(1)
380
+ transform = src.transform
381
+
382
+ labels = np.unique(mask)
383
+ labels = labels[labels != 0]
384
+
385
+ data = []
386
+
387
+ # Generate polygons and their values
388
+ shapes_gen = rasterio.features.shapes(mask, mask=(mask != 0), transform=transform)
389
+
390
+ # Create a dict to collect polygons by label
391
+ polygons_by_label = {}
392
+
393
+ for geom, val in shapes_gen:
394
+ if val == 0:
395
+ continue
396
+ polygons_by_label.setdefault(val, []).append(shape(geom))
397
+
398
+ # For each label, merge polygons and get centroid
399
+ for idx, label in enumerate(labels):
400
+ yield f"Processing {idx+1} out of {len(labels)}"
401
+ polygons = polygons_by_label.get(label)
402
+ if not polygons:
403
+ continue
404
+
405
+ # Merge polygons of the same label (if multiple parts)
406
+ multi_poly = polygons[0]
407
+ for poly in polygons[1:]:
408
+ multi_poly = multi_poly.union(poly)
409
+
410
+ centroid = multi_poly.centroid
411
+ data.append({"blob_id": label, "x": centroid.x, "y": centroid.y})
412
+
413
+ df = pd.DataFrame(data)
414
+ COORD_PATH=os.path.join(OUTPUT_DIR,"centroids.csv")
415
+ df.to_csv(COORD_PATH, index=False)
416
+ yield f"Saved centroid coordinates of {len(labels)} blobs."
417
+
418
+
419
+
420
+
421
+ def georefTile(tile_coords, gcp_path):
422
+ yield "Georeferencing SAM image.."
423
+
424
+
425
+ MASK_TILE=os.path.join(OUTPUT_DIR,"mask.tif")
426
+ TMP_TILE=os.path.join(OUTPUT_DIR,"mask_tmp.tif")
427
+ MASK_TILE_GEO=os.path.join(OUTPUT_DIR,"mask_georef.tif")
428
+
429
+ for f in [TMP_TILE, MASK_TILE_GEO]:
430
+ if os.path.exists(f):
431
+ os.remove(f)
432
+
433
+ df = pd.read_csv(gcp_path)
434
+
435
+ xmin, ymin, xmax, ymax = tile_coords
436
+ xoff, yoff = xmin, ymin
437
+ xsize, ysize = xmax - xmin, ymax - ymin
438
+
439
+ shifted_gcps = []
440
+ for _, r in df.iterrows():
441
+ shifted_gcps.append(
442
+ gdal.GCP(
443
+ float(r['mapX']),
444
+ float(r['mapY']),
445
+ 0,
446
+ float(r['sourceX']) - xoff,
447
+ abs(float(r['sourceY'])) - yoff
448
+ )
449
+ )
450
+
451
+ gdal.Translate(
452
+ TMP_TILE,
453
+ MASK_TILE,
454
+ format="GTiff",
455
+ GCPs=shifted_gcps,
456
+ outputSRS="EPSG:3857"
457
  )
458
+
459
+ gdal.Warp(
460
+ MASK_TILE_GEO,
461
+ TMP_TILE,
462
+ dstSRS="EPSG:3857",
463
+ resampleAlg="near",
464
+ polynomialOrder=1,
465
+ creationOptions=["COMPRESS=LZW"]
466
  )
467
+
468
+ yield "Done."
469
+
470
+
471
+
472
+ def georefImg(image_path, gcp_path):
473
+
474
+ yield "Reading GCP CSV..."
475
+
476
+ TMP_FILE = os.path.join(OUTPUT_DIR,"tmp.tif")
477
+ GEO_FILE = os.path.join(OUTPUT_DIR,"georeferenced.tif")
478
+
479
+ for f in [TMP_FILE, GEO_FILE]:
480
+ if os.path.exists(f):
481
+ os.remove(f)
482
+
483
+ df = pd.read_csv(gcp_path)
484
+
485
+ H,W,_ = img_shape(image_path)
486
+
487
+
488
+ # Build GCPs
489
+ gcps = []
490
+ for _, r in df.iterrows():
491
+ gcps.append(
492
+ gdal.GCP(
493
+ float(r['mapX']),
494
+ float(r['mapY']),
495
+ 0,
496
+ float(r['sourceX']),
497
+ #H-float(r['sourceY'])
498
+ abs(float(r['sourceY']))
499
+ )
500
+ )
501
+
502
+
503
+
504
+
505
+
506
+ gdal.Translate(
507
+ TMP_FILE,
508
+ image_path,
509
+ format="GTiff",
510
+ GCPs=gcps,
511
+ outputSRS="EPSG:3857"
512
+ )
513
+
514
+
515
+
516
+
517
+ yield "Running gdalwarp..."
518
+
519
+ gdal.Warp(
520
+ GEO_FILE,
521
+ TMP_FILE,
522
+ dstSRS="EPSG:3857",
523
+ resampleAlg="near",
524
+ polynomialOrder=1,
525
+ creationOptions=["COMPRESS=LZW"]
526
  )
527
 
528
+
529
+
530
+ yield "Done."
531
+
532
+
533
+ def extractStreetNet(city_name):
534
+ yield f"Extract OSM street network for {city_name}"
535
+
536
+ MASK_TILE_GEO=os.path.join(OUTPUT_DIR,"mask_georef.tif")
537
+
538
+ ds = gdal.Open(MASK_TILE_GEO)
539
+ gt = ds.GetGeoTransform()
540
+ width = ds.RasterXSize
541
+ height = ds.RasterYSize
542
+
543
+ minx = gt[0]
544
+ maxy = gt[3]
545
+ maxx = gt[0] + width * gt[1] + height * gt[2]
546
+ miny = gt[3] + width * gt[4] + height * gt[5]
547
+
548
+ # Add 100 meters buffer in all directions
549
+ minx -= 100 # west
550
+ maxx += 100 # east
551
+ miny -= 100 # south
552
+ maxy += 100 # north
553
+
554
+ bbox = (maxy, miny, maxx, minx)
555
+
556
+
557
+ transformer = Transformer.from_crs("EPSG:3857", "EPSG:4326", always_xy=True)
558
+ north, south = transformer.transform(bbox[2], bbox[0])[1], transformer.transform(bbox[3], bbox[1])[1]
559
+ east, west = transformer.transform(bbox[2], bbox[0])[0], transformer.transform(bbox[3], bbox[1])[0]
560
+
561
+ bbox = (west, south, east, north)
562
+
563
+ G = ox.graph_from_bbox(bbox,network_type='all')
564
+ G_proj = ox.project_graph(G)
565
+ edges = ox.graph_to_gdfs(G_proj, nodes=False, edges=True, fill_edge_geometry=True)
566
+ edges_3857 = edges.to_crs(epsg=3857)
567
+ edges_3857 = edges_3857[['osmid','name', 'geometry']]
568
+ edges_3857 = edges_3857[edges_3857['name'].notnull()]
569
+
570
+ edges_3857['name'] = edges_3857['name'].apply(
571
+ lambda x: x[0] if isinstance(x, list) and len(x) > 0 else x)
572
+
573
+ OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
574
+ edges_3857.to_file(OSM_PATH, driver="GeoJSON")
575
+ yield "Done."
576
+
577
+
578
+
579
+ def fuzzyMatch(score_th):
580
+ COORD_PATH=os.path.join(OUTPUT_DIR,"centroids.csv")
581
+ OCR_PATH=os.path.join(OUTPUT_DIR,"ocr.csv")
582
+ coords_df = pd.read_csv(COORD_PATH)
583
+ names_df = pd.read_csv(OCR_PATH,
584
+ names=['blob_id','pred_text'],
585
+ dtype={"blob_id": "int64", "pred_text": "string"})
586
+ merged_df = coords_df.merge(names_df, on="blob_id")
587
+
588
+ gdf = gpd.GeoDataFrame(
589
+ merged_df,
590
+ geometry=gpd.points_from_xy(merged_df.x, merged_df.y),
591
+ crs="EPSG:3857"
592
  )
593
 
594
+ OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
595
+ osm_gdf = gpd.read_file(OSM_PATH,dtype={"name": "str"})
596
+
597
+ yield "Process OSM candidates..."
598
+ results = []
599
+ for _, row in gdf.iterrows():
600
+ match = best_street_match(row.geometry, row['pred_text'], osm_gdf, max_distance=100)
601
+ if match:
602
+ results.append({
603
+ "blob_id": row.blob_id,
604
+ "x": row.x,
605
+ "y": row.y,
606
+ "blob_name": row.pred_text,
607
+ "best_osm_match": match[0],
608
+ "osm_match_score": match[1]
609
+ })
610
+ else:
611
+ results.append({
612
+ "blob_id": row.blob_id,
613
+ "x": row.x,
614
+ "y": row.y,
615
+ "blob_name": row.pred_text,
616
+ "best_osm_match": None,
617
+ "osm_match_score": 0
618
+ })
619
+
620
+ results_df = pd.DataFrame(results)
621
+ RES_PATH=os.path.join(OUTPUT_DIR,"street_matches.csv")
622
+ results_df.to_csv(RES_PATH, index=False)
623
+
624
+ # remove street labels from blobs folder that are more than or equal to score threshold
625
+ manual_df = results_df[results_df['osm_match_score'] >= int(score_th)]
626
+
627
+ for blob_id in manual_df['blob_id']:
628
+ file_path = os.path.join(OUTPUT_DIR,"blobs",f"{blob_id}.png")
629
+
630
+ if os.path.exists(file_path):
631
+ os.remove(file_path)
632
 
633
+ yield f"{RES_PATH}"
inference_tab/inference_setup.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  TILE_SIZE = 1024
8
  TILE_FOLDER = "tiles"
9
  os.makedirs(TILE_FOLDER, exist_ok=True)
10
- tiles_cache = {"tiles": [], "selected_tile": None, "processed_tiles": set()}
11
 
12
 
13
  def make_tiles(image, tile_size=TILE_SIZE):
@@ -20,44 +20,29 @@ def make_tiles(image, tile_size=TILE_SIZE):
20
  for x in range(0, w, tile_size):
21
  tile = image[y:y+tile_size, x:x+tile_size]
22
  tiles.append(((x, y, x+tile_size, y+tile_size), tile))
23
-
24
- # Draw thick rectangle for readability
25
- cv2.rectangle(annotated, (x, y), (x+tile_size, y+tile_size), (255, 0, 0), 6)
26
- cv2.putText(annotated, str(tile_id), (x+50, y+100),
27
- cv2.FONT_HERSHEY_SIMPLEX, 4, (0, 0, 0), 8)
28
-
29
- # Shade processed tiles
30
- if tile_id in tiles_cache["processed_tiles"]:
31
- overlay = annotated[y:y+tile_size, x:x+tile_size].copy()
32
- overlay[:] = (0, 255, 0) # light green
33
- alpha = 0.4
34
- annotated[y:y+tile_size, x:x+tile_size] = cv2.addWeighted(
35
- overlay, alpha, annotated[y:y+tile_size, x:x+tile_size], 1-alpha, 0
36
- )
37
-
38
  tile_id += 1
39
-
40
  return annotated, tiles
41
 
42
-
43
  def create_tiles(image_file):
44
  img = Image.open(image_file.name).convert("RGB")
45
- img_np = np.array(img)
46
 
47
- annotated, tiles = make_tiles(img_np, TILE_SIZE)
48
  tiles_cache["tiles"] = []
49
 
50
  for idx, (coords, tile) in enumerate(tiles):
51
  tile_path = os.path.join(TILE_FOLDER, f"tile_{idx}.png")
52
  Image.fromarray(tile).save(tile_path)
53
- tiles_cache["tiles"].append((coords, tile_path))
54
 
55
  tiles_cache["selected_tile"] = None
56
- tiles_cache["processed_tiles"] = set()
57
  return annotated, gr.update(interactive=False)
58
 
59
-
60
- def select_tile(image, evt: gr.SelectData, state):
61
  if not tiles_cache["tiles"]:
62
  return None, gr.update(interactive=False), state
63
 
@@ -66,70 +51,73 @@ def select_tile(image, evt: gr.SelectData, state):
66
 
67
  if 0 <= tile_id < len(tiles_cache["tiles"]):
68
  coords, tile_path = tiles_cache["tiles"][tile_id]
69
- tiles_cache["selected_tile"] = {"tile_path": tile_path, "coords": coords, "tile_id": tile_id}
70
 
71
- updated_state = {"tile_path": tile_path, "coords": coords, "tile_id": tile_id}
 
 
 
 
 
 
 
 
 
 
 
72
  tile_array = np.array(Image.open(tile_path))
73
  cv2.putText(tile_array, str(tile_id), (100, 100),
74
- cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 4, cv2.LINE_AA)
75
- return tile_array, gr.update(interactive=True), updated_state
76
 
77
- return None, gr.update(interactive=False), state
78
 
 
79
 
80
- # Wrapper to shade processed tile after running inference
81
- def run_inference_with_shading(selected_tile_state, gcp_input, city_name, score_th, annotated_image, run_inference_fn):
82
- # Call original inference
83
- output, download_file = run_inference_fn(selected_tile_state, gcp_input, city_name, score_th)
84
 
85
- # Mark tile as processed
86
- tile_info = tiles_cache.get("selected_tile")
87
- if tile_info:
88
- tiles_cache["processed_tiles"].add(tile_info["tile_id"])
89
 
90
- # Update annotated map with shading
91
- annotated, _ = make_tiles(annotated_image, TILE_SIZE)
92
- return annotated, output, download_file
93
 
94
 
95
- def get_inference_widgets(run_inference, georefImg):
96
  with gr.Row():
97
- with gr.Column(scale=1, min_width=500):
98
- annotated_out = gr.Image(type="numpy", label="City Map", height=500, width=500)
 
 
 
 
99
  city_name = gr.Textbox(label="Enter city name")
100
  image_input = gr.File(label="Select Image File")
101
  gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
102
  create_btn = gr.Button("Create Tiles")
103
  georef_btn = gr.Button("Georeference Full Map")
 
104
 
 
105
  with gr.Column(scale=1):
106
- selected_tile = gr.Image(type="numpy", label="Selected Tile", height=500, width=500)
 
 
 
107
  score_th = gr.Textbox(label="Enter a score threshold below which to annotate manually")
108
  run_button = gr.Button("Run Inference", interactive=False)
109
  output = gr.Textbox(label="Progress", lines=5, interactive=False)
110
  download_file = gr.File(label="Download CSV")
111
 
112
- selected_tile_state = gr.State()
113
- annotated_image_state = gr.State()
114
 
115
  # Wire events
116
  create_btn.click(
117
- fn=create_tiles,
118
- inputs=image_input,
119
  outputs=[annotated_out, run_button]
120
  )
121
-
122
  annotated_out.select(
123
- fn=select_tile,
124
- inputs=[annotated_out, selected_tile_state],
125
- outputs=[selected_tile, run_button, selected_tile_state]
126
  )
127
-
128
  run_button.click(
129
- fn=lambda selected_tile_state, gcp_input, city_name, score_th, annotated_image:
130
- run_inference_with_shading(selected_tile_state, gcp_input, city_name, score_th, annotated_image, run_inference),
131
- inputs=[selected_tile_state, gcp_input, city_name, score_th, annotated_out],
132
- outputs=[annotated_out, output, download_file]
133
  )
134
 
135
  georef_btn.click(
@@ -138,4 +126,5 @@ def get_inference_widgets(run_inference, georefImg):
138
  outputs=[output]
139
  )
140
 
141
- return image_input, gcp_input, city_name, score_th, run_button, output, download_file
 
 
7
  TILE_SIZE = 1024
8
  TILE_FOLDER = "tiles"
9
  os.makedirs(TILE_FOLDER, exist_ok=True)
10
+ tiles_cache = {"tiles": [], "selected_tile": None}
11
 
12
 
13
  def make_tiles(image, tile_size=TILE_SIZE):
 
20
  for x in range(0, w, tile_size):
21
  tile = image[y:y+tile_size, x:x+tile_size]
22
  tiles.append(((x, y, x+tile_size, y+tile_size), tile))
23
+ cv2.rectangle(annotated, (x, y), (x+tile_size, y+tile_size), (255,0,0), 2)
24
+ cv2.putText(annotated, str(tile_id), (x+50, y+50),
25
+ cv2.FONT_HERSHEY_SIMPLEX, 2, (0,0,0), 5)
 
 
 
 
 
 
 
 
 
 
 
 
26
  tile_id += 1
 
27
  return annotated, tiles
28
 
 
29
  def create_tiles(image_file):
30
  img = Image.open(image_file.name).convert("RGB")
31
+ img = np.array(img)
32
 
33
+ annotated, tiles = make_tiles(img, TILE_SIZE)
34
  tiles_cache["tiles"] = []
35
 
36
  for idx, (coords, tile) in enumerate(tiles):
37
  tile_path = os.path.join(TILE_FOLDER, f"tile_{idx}.png")
38
  Image.fromarray(tile).save(tile_path)
39
+ tiles_cache["tiles"].append((coords, tile_path)) # store path instead of array
40
 
41
  tiles_cache["selected_tile"] = None
 
42
  return annotated, gr.update(interactive=False)
43
 
44
+ def select_tile(evt: gr.SelectData,state):
45
+ # compute tile index
46
  if not tiles_cache["tiles"]:
47
  return None, gr.update(interactive=False), state
48
 
 
51
 
52
  if 0 <= tile_id < len(tiles_cache["tiles"]):
53
  coords, tile_path = tiles_cache["tiles"][tile_id]
 
54
 
55
+ # store the path, not the array
56
+ tiles_cache["selected_tile"] = {
57
+ "tile_path": tile_path,
58
+ "coords": coords
59
+ }
60
+
61
+ updated_state = {
62
+ "tile_path": tile_path,
63
+ "coords": coords
64
+ }
65
+
66
+ # load tile only for display
67
  tile_array = np.array(Image.open(tile_path))
68
  cv2.putText(tile_array, str(tile_id), (100, 100),
69
+ cv2.FONT_HERSHEY_SIMPLEX, 2, (0,0,0), 4, cv2.LINE_AA)
 
70
 
71
+ return tile_array, gr.update(interactive=True),updated_state
72
 
73
+ return None, gr.update(interactive=False), state
74
 
 
 
 
 
75
 
 
 
 
 
76
 
 
 
 
77
 
78
 
79
+ def get_inference_widgets(run_inference,georefImg):
80
  with gr.Row():
81
+ # Left column
82
+ with gr.Column(scale=1,min_width=500):
83
+ annotated_out = gr.Image(
84
+ type="numpy", label="City Map",
85
+ height=500, width=500
86
+ )
87
  city_name = gr.Textbox(label="Enter city name")
88
  image_input = gr.File(label="Select Image File")
89
  gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
90
  create_btn = gr.Button("Create Tiles")
91
  georef_btn = gr.Button("Georeference Full Map")
92
+
93
 
94
+ # Right column
95
  with gr.Column(scale=1):
96
+ selected_tile = gr.Image(
97
+ type="numpy", label="Selected Tile",
98
+ height=500, width=500
99
+ )
100
  score_th = gr.Textbox(label="Enter a score threshold below which to annotate manually")
101
  run_button = gr.Button("Run Inference", interactive=False)
102
  output = gr.Textbox(label="Progress", lines=5, interactive=False)
103
  download_file = gr.File(label="Download CSV")
104
 
105
+ selected_tile_path = gr.State()
106
+
107
 
108
  # Wire events
109
  create_btn.click(
110
+ fn=create_tiles, inputs=image_input,
 
111
  outputs=[annotated_out, run_button]
112
  )
 
113
  annotated_out.select(
114
+ fn=select_tile, inputs=[selected_tile_path],
115
+ outputs=[selected_tile, run_button, selected_tile_path]
 
116
  )
 
117
  run_button.click(
118
+ fn=run_inference,
119
+ inputs=[selected_tile_path, gcp_input, city_name, score_th],
120
+ outputs=[output, download_file]
 
121
  )
122
 
123
  georef_btn.click(
 
126
  outputs=[output]
127
  )
128
 
129
+
130
+ return image_input, gcp_input, city_name, score_th, run_button, output, download_file