dhammo2 commited on
Commit
617c74e
·
verified ·
1 Parent(s): 8ac92fc

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +37 -0
  2. geo_boundary_translator.py +926 -0
  3. spatial_diffusion.py +1059 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from spatial_diffusion import create_diffusion_interface
3
+ from geo_boundary_translator import create_translator_interface
4
+
5
def create_combined_interface():
    """Create the main application with tabs for different geographic tools.

    Returns:
        gr.Blocks: The assembled Gradio app containing one tab per tool.
    """
    with gr.Blocks(title="Geographic Analysis Toolkit") as combined_app:
        # Branded header banner. sanitize_html=False so the inline-styled
        # <div> below is rendered as raw HTML instead of being escaped.
        gr.Markdown(
            """
            <div style="
                background-color: #4B23C0;
                color: white;
                padding: 20px;
                text-align: left;
                font-size: 28px;
                font-weight: bold;
                margin: 0;
                border-radius: 4px;
            ">
            MOPAC | DS &nbsp;-&nbsp;🗺️ Geographic Analysis Toolkit
            </div>
            """,
            sanitize_html=False
        )

        # One tab per tool; each create_*_interface() factory builds its own
        # component tree inside the currently-open TabItem context.
        with gr.Tabs() as tabs:
            with gr.TabItem("GeoBoundary Translator"):
                create_translator_interface()

            with gr.TabItem("Spatial Diffusion Tool"):
                create_diffusion_interface()

    return combined_app
34
+
35
if __name__ == "__main__":
    # Build the combined toolkit and serve it when run as a script.
    create_combined_interface().launch()
geo_boundary_translator.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import geopandas as gpd
3
+ import pandas as pd
4
+ import json
5
+ import tempfile
6
+ import os
7
+ import shutil
8
+ import matplotlib.pyplot as plt
9
+ import contextily as ctx
10
+ from matplotlib.colors import LinearSegmentedColormap
11
+ import numpy as np
12
+ import fiona
13
+ import zipfile
14
+ from typing import List, Tuple, Dict, Optional, Any, Union
15
+
16
def extract_columns_from_geo_file(file_obj, progress=None):
    """Extract column names from a geospatial file (GeoJSON or GeoPackage).

    Args:
        file_obj: Uploaded file object exposing a ``.name`` path attribute.
        progress (gr.Progress, optional): Gradio progress tracker.

    Returns:
        tuple: (list of non-geometry column names, GeoDataFrame) on success,
        or ([], None) on any failure.
    """
    temp_dir = None
    try:
        file_extension = os.path.splitext(file_obj.name)[1].lower()
        msg = f"Processing file: {file_obj.name} with extension {file_extension}"
        print(msg)
        if progress is not None:
            progress(0.1, desc=msg)

        # Work on a private temporary copy to avoid file locking issues
        # with the uploaded original.
        temp_dir = tempfile.mkdtemp()
        temp_file_path = os.path.join(temp_dir, f"temp_geo_file{file_extension}")
        shutil.copyfile(file_obj.name, temp_file_path)

        msg = f"Created temporary copy at: {temp_file_path}"
        print(msg)
        if progress is not None:
            progress(0.3, desc=msg)

        if file_extension in ('.geojson', '.json'):
            msg = "Reading GeoJSON file..."
            if progress is not None:
                progress(0.5, desc=msg)
            gdf = gpd.read_file(temp_file_path)
        elif file_extension == '.gpkg':
            # A GeoPackage may hold several layers; always read the first
            # and warn when others are being ignored.
            msg = "Reading GeoPackage layers..."
            if progress is not None:
                progress(0.5, desc=msg)
            layers = fiona.listlayers(temp_file_path)

            if not layers:
                raise ValueError("No layers found in GeoPackage.")

            gdf = gpd.read_file(temp_file_path, layer=layers[0])
            if len(layers) > 1:
                print(f"Multiple layers found in GeoPackage. Using '{layers[0]}'. Available layers: {layers}")
        else:
            raise ValueError(f"Unsupported file format: {file_extension}")

        # Report every attribute column; geometry is handled separately.
        columns = [col for col in gdf.columns if col != 'geometry']
        msg = f"Extracted columns: {columns}"
        print(msg)
        if progress is not None:
            progress(0.8, desc=msg)

        if progress is not None:
            progress(1.0, desc="File processed successfully")

        return columns, gdf

    except Exception as e:
        error_msg = f"Error extracting columns: {str(e)}"
        print(error_msg)
        if progress is not None:
            progress(1.0, desc=error_msg)
        return [], None
    finally:
        # Best-effort cleanup of the temporary copy on both success and
        # failure paths (previously the failure path used a silent bare
        # except and the success path duplicated this logic).
        if temp_dir is not None:
            try:
                shutil.rmtree(temp_dir)
            except OSError as e:
                print(f"Warning: Could not clean up temporary directory: {str(e)}")
94
+
95
def extract_columns_from_csv(file_obj, progress=None):
    """Read a CSV upload and report its column names.

    Args:
        file_obj: Uploaded file object exposing a ``.name`` path attribute.
        progress (gr.Progress, optional): Gradio progress tracker.

    Returns:
        tuple: (list of column names, DataFrame) on success, or ([], None)
        when the file cannot be read.
    """
    try:
        msg = f"Reading CSV file: {file_obj.name}"
        print(msg)
        if progress is not None:
            progress(0.2, desc=msg)

        # Load the whole CSV into memory; column order is preserved.
        frame = pd.read_csv(file_obj.name)
        names = list(frame.columns)

        msg = f"Extracted CSV columns: {names}"
        print(msg)
        if progress is not None:
            progress(1.0, desc=msg)

        return names, frame
    except Exception as e:
        error_msg = f"Error extracting columns from CSV: {str(e)}"
        print(error_msg)
        if progress is not None:
            progress(1.0, desc=error_msg)
        return [], None
121
+
122
def create_map_visualization(gdf, title, progress=None):
    """Render a GeoDataFrame to a PNG map image with an optional basemap.

    Args:
        gdf (GeoDataFrame): Geometry to plot; WGS84 is assumed if it has no CRS.
        title (str): Title drawn above the map.
        progress (gr.Progress, optional): Gradio progress tracker.

    Returns:
        str or None: Filesystem path of the saved PNG, or None on failure.
    """
    fig = None
    try:
        # Nothing to draw without data.
        if gdf is None:
            msg = "Cannot create map: GeoDataFrame is None"
            print(msg)
            if progress is not None:
                progress(1.0, desc=msg)
            return None

        msg = "Creating map visualization..."
        print(msg)
        if progress is not None:
            progress(0.2, desc=msg)

        # delete=False so the path outlives this call; Gradio serves it later.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        temp_filename = temp_file.name
        temp_file.close()

        # Web Mercator (EPSG:3857) is required for contextily basemap tiles.
        if gdf.crs is None:
            print("No CRS found in GeoDataFrame, assuming WGS84")
            gdf = gdf.set_crs("EPSG:4326")

        if progress is not None:
            progress(0.4, desc="Converting coordinate system...")

        gdf_webmerc = gdf.to_crs("EPSG:3857")

        if progress is not None:
            progress(0.6, desc="Creating plot...")

        fig, ax = plt.subplots(1, 1, figsize=(10, 8))

        # Custom purple ramp to match the app theme color (#4B23C0).
        colors = ['#f5f0ff', '#e6d9ff', '#d6c2ff', '#c7abff', '#b894ff', '#a87dff', '#9966ff', '#8a4fff', '#7b38ff', '#6c21ff', '#5d0af0', '#4B23C0']
        custom_cmap = LinearSegmentedColormap.from_list('custom_purples', colors)

        # Random value per feature purely to vary the fill colors.
        random_values = np.random.rand(len(gdf_webmerc))
        gdf_webmerc.plot(ax=ax, column=random_values, cmap=custom_cmap,
                         alpha=0.7, edgecolor='#333333', linewidth=0.5)

        if progress is not None:
            progress(0.8, desc="Adding basemap...")

        try:
            # Basemap fetch needs network access; the map is still usable
            # without it, so failures are logged and ignored.
            ctx.add_basemap(ax, source=ctx.providers.CartoDB.Positron)
        except Exception as e:
            print(f"Could not add basemap: {str(e)}")

        ax.set_title(title, fontsize=16, color='#4B23C0')  # Purple title to match theme
        ax.set_axis_off()
        plt.tight_layout()

        if progress is not None:
            progress(0.9, desc="Saving map image...")

        plt.savefig(temp_filename, dpi=150, bbox_inches='tight')
        plt.close(fig)
        fig = None

        if progress is not None:
            progress(1.0, desc="Map created successfully")

        return temp_filename

    except Exception as e:
        error_msg = f"Error creating map visualization: {str(e)}"
        print(error_msg)
        if progress is not None:
            progress(1.0, desc=error_msg)
        return None
    finally:
        # Fix: close the figure even when an exception interrupts plotting;
        # previously failed calls leaked open matplotlib figures.
        if fig is not None:
            plt.close(fig)
205
+
206
def calculate_areal_intersection(original_gdf, original_id, new_gdf, new_id, progress=None):
    """
    Calculate the areal intersection between two geographic datasets.

    Args:
        original_gdf (GeoDataFrame): Original geography
        original_id (str): ID column in original geography
        new_gdf (GeoDataFrame): New geography
        new_id (str): ID column in new geography
        progress (gr.Progress, optional): Gradio progress tracker

    Returns:
        DataFrame: One row per intersecting pair with overlap areas and the
        percentage of each source polygon covered. Empty DataFrame on error.
    """
    try:
        total_combinations = len(original_gdf) * len(new_gdf)
        msg = f"Calculating areal intersection between {len(original_gdf)} original areas and {len(new_gdf)} new areas..."
        print(msg)

        if progress is not None:
            progress(0, desc=msg)

        # Both layers must share a CRS for intersection areas to be meaningful.
        if original_gdf.crs != new_gdf.crs:
            crs_msg = f"Converting CRS from {original_gdf.crs} to {new_gdf.crs}"
            print(crs_msg)
            if progress is not None:
                progress(0, desc=crs_msg)
            original_gdf = original_gdf.to_crs(new_gdf.crs)

        # Fix: accumulate rows in a plain list and build the DataFrame once at
        # the end. The previous pd.concat-per-row made this loop quadratic.
        records = []
        processed = 0

        # Brute-force pairwise comparison of every original/new polygon pair.
        for idx1, row1 in original_gdf.iterrows():
            orig_id = row1[original_id]
            orig_geom = row1['geometry']
            orig_area = orig_geom.area

            for idx2, row2 in new_gdf.iterrows():
                new_id_val = row2[new_id]
                new_geom = row2['geometry']
                new_area = new_geom.area

                if orig_geom.intersects(new_geom):
                    intersection_area = orig_geom.intersection(new_geom).area

                    # Guard against zero-area degenerate geometries.
                    pct_of_original = (intersection_area / orig_area) * 100 if orig_area > 0 else 0
                    pct_of_new = (intersection_area / new_area) * 100 if new_area > 0 else 0

                    # Keep only non-trivial overlaps (>0.01% of either area).
                    if pct_of_original > 0.01 or pct_of_new > 0.01:
                        records.append({
                            'original_id': orig_id,
                            'new_id': new_id_val,
                            'area_original': orig_area,
                            'area_new': new_area,
                            'area_overlap': intersection_area,
                            'pct_of_original': pct_of_original,
                            'pct_of_new': pct_of_new,
                        })

                # Update progress after every pair.
                processed += 1
                progress_pct = processed / total_combinations

                if progress is not None:
                    progress_msg = f"Calculating intersections: {int(progress_pct*100)}% complete"
                    progress(progress_pct, desc=progress_msg)

                    # Also log to console every 10%
                    if int(progress_pct*100) % 10 == 0 and int(progress_pct*100) > 0:
                        print(f"Intersection calculation: {int(progress_pct*100)}% complete")

        overlap_df = pd.DataFrame(records, columns=['original_id', 'new_id', 'area_original', 'area_new', 'area_overlap', 'pct_of_original', 'pct_of_new'])

        complete_msg = f"Intersection calculation complete. Found {len(overlap_df)} intersections."
        print(complete_msg)
        if progress is not None:
            progress(1.0, desc=complete_msg)

        return overlap_df

    except Exception as e:
        error_msg = f"Error calculating areal intersection: {str(e)}"
        print(error_msg)
        if progress is not None:
            progress(1.0, desc=error_msg)
        return pd.DataFrame()
303
+
304
def generate_weights_matrix(overlap_df, progress=None):
    """Pivot pairwise overlap records into a weights matrix.

    Args:
        overlap_df (DataFrame): Output from calculate_areal_intersection.
        progress (gr.Progress, optional): Gradio progress tracker.

    Returns:
        DataFrame: Original IDs as rows, new IDs as columns; each cell is the
        percentage of the original area falling inside the new area, 0 where
        there is no overlap. Empty DataFrame on error.
    """
    try:
        msg = "Generating weights matrix from intersection data..."
        print(msg)
        if progress is not None:
            progress(0.1, desc=msg)

        # Cell value = share of the original area contributed to each new area.
        pivoted = overlap_df.pivot(index='original_id', columns='new_id', values='pct_of_original')
        weights_matrix = pivoted.fillna(0)

        # Sanity check: each row should total roughly 100% when the two
        # geographies cover the same region.
        row_sums = weights_matrix.sum(axis=1)
        stats_msg = f"Row sum statistics: min={row_sums.min():.2f}%, max={row_sums.max():.2f}%, mean={row_sums.mean():.2f}%"
        print(stats_msg)

        if progress is not None:
            progress(1.0, desc="Weights matrix generated successfully")

        return weights_matrix

    except Exception as e:
        error_msg = f"Error generating weights matrix: {str(e)}"
        print(error_msg)
        if progress is not None:
            progress(1.0, desc=error_msg)
        return pd.DataFrame()
345
+
346
def check_fields(original_file, original_id, new_file, new_id, stats_file=None, stats_id=None, stats_cols=None):
    """Determine which operations the current form inputs allow.

    Returns:
        tuple: (translation_ready, weights_only_ready, message)
    """
    # Weights calculation needs both geographies and their ID columns.
    have_geographies = all(
        value is not None
        for value in (original_file, original_id, new_file, new_id)
    )

    # Full translation additionally needs a statistics file, its ID column,
    # and at least one selected statistic column.
    have_stats = (stats_file is not None and stats_id is not None
                  and stats_cols is not None and len(stats_cols) > 0)

    if have_geographies and have_stats:
        return True, True, "Ready to translate statistics"
    if have_geographies:
        return False, True, "Ready to calculate weights matrix (no statistics will be translated)"
    if original_file is not None and new_file is not None:
        return False, False, "Please select ID columns"
    return False, False, "Please upload required files"
368
+
369
def calculate_weights_only(original_file, original_id, new_file, new_id, progress=None):
    """
    Calculate the weights matrix between two geographies without translating statistics.

    Args:
        original_file: File object for original geography
        original_id: ID column in original geography
        new_file: File object for new geography
        new_id: ID column in new geography
        progress (gr.Progress, optional): Gradio progress tracker

    Returns:
        Tuple of (results_visible, summary_text, zip_path, weights_path)
    """
    def _temp_path(suffix):
        # delete=False so the file survives close(); Gradio serves downloads
        # by path after this function returns.
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        path = handle.name
        handle.close()
        return path

    try:
        # Read the two geographies.
        if progress is not None:
            progress(0, desc="Reading original geography...")
        print("Reading original geography...")
        orig_columns, orig_gdf = extract_columns_from_geo_file(original_file, progress)

        if progress is not None:
            progress(0.1, desc="Reading new geography...")
        print("Reading new geography...")
        new_columns, new_gdf = extract_columns_from_geo_file(new_file, progress)

        # Pairwise polygon overlap between the two geographies.
        if progress is not None:
            progress(0.2, desc="Preparing to calculate areal intersection...")
        print("Calculating areal intersection...")
        overlap_df = calculate_areal_intersection(orig_gdf, original_id, new_gdf, new_id, progress)

        if overlap_df.empty:
            if progress is not None:
                progress(1.0, desc="Error: Could not calculate area overlap between geographies.")
            return True, "Error: Could not calculate area overlap between geographies. Check that they cover the same region.", None, None

        # Convert overlaps into the original-by-new weights matrix.
        if progress is not None:
            progress(0.9, desc="Generating weights matrix...")
        print("Generating weights matrix...")
        weights_matrix = generate_weights_matrix(overlap_df, progress)

        if weights_matrix.empty:
            if progress is not None:
                progress(1.0, desc="Error: Could not generate weights matrix.")
            return True, "Error: Could not generate weights matrix.", None, None

        # Persist outputs for download.
        if progress is not None:
            progress(0.95, desc="Saving results...")

        weights_path = _temp_path('_weights.csv')
        weights_matrix.to_csv(weights_path)
        print(f"Saved weights matrix to {weights_path}")

        # Also save the full intersection data with detailed overlap info.
        overlap_path = _temp_path('_overlap_details.csv')
        overlap_df.to_csv(overlap_path, index=False)
        print(f"Saved detailed overlap data to {overlap_path}")

        # Bundle everything into a single ZIP for the "download all" button.
        zip_path = _temp_path('.zip')
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            zipf.write(weights_path, arcname="weights_matrix.csv")
            zipf.write(overlap_path, arcname="overlap_details.csv")

        print(f"Created ZIP archive at {zip_path}")

        # Human-readable run summary shown in the results panel.
        summary = f"""
        Weights calculation complete!
        - Processed {len(orig_gdf)} original areas and {len(new_gdf)} new areas.
        - Found {len(overlap_df)} geographic intersections between areas.

        The download contains:
        - weights_matrix.csv: The weights matrix for future translations
        - overlap_details.csv: Detailed area overlap information
        """

        if progress is not None:
            progress(1.0, desc="Weights calculation complete!")

        return True, summary, zip_path, weights_path

    except Exception as e:
        print(f"Error calculating weights: {str(e)}")
        import traceback
        traceback.print_exc()
        if progress is not None:
            progress(1.0, desc=f"Error during weights calculation: {str(e)}")
        return True, f"Error during weights calculation: {str(e)}", None, None
470
+
471
def translate_statistics(original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols, progress=None):
    """
    Translate statistics from the original geography to the new geography.

    Uses areal-weighted interpolation: each statistic value is apportioned to
    new areas by the share of the original area that overlaps them.

    Args:
        original_file: File object for original geography
        original_id: ID column in original geography
        new_file: File object for new geography
        new_id: ID column in new geography
        stats_file: File object for statistics
        stats_id: ID column in statistics
        stats_cols: List of statistic columns to translate
        progress (gr.Progress, optional): Gradio progress tracker

    Returns:
        Tuple of (results_visible, summary_text, output_file_path, weights_file_path)
    """
    try:
        # Read the geographies
        if progress is not None:
            progress(0.05, desc="Reading original geography...")
        print("Reading original geography...")
        orig_columns, orig_gdf = extract_columns_from_geo_file(original_file, progress)

        if progress is not None:
            progress(0.1, desc="Reading new geography...")
        print("Reading new geography...")
        new_columns, new_gdf = extract_columns_from_geo_file(new_file, progress)

        if progress is not None:
            progress(0.15, desc="Reading statistics...")
        print("Reading statistics...")
        stats_columns, stats_df = extract_columns_from_csv(stats_file, progress)

        # Check that the stats_id exists in the statistics file
        if stats_id not in stats_df.columns:
            if progress is not None:
                progress(1.0, desc=f"Error: Statistics file does not contain column '{stats_id}'")
            return True, f"Error: Statistics file does not contain column '{stats_id}'", None, None

        # Create lookup between stats and original geography. IDs are compared
        # as strings so numeric/text ID column mismatches still match.
        if progress is not None:
            progress(0.2, desc="Creating ID lookup between statistics and original geography...")
        print("Creating ID lookup between statistics and original geography...")
        stats_ids = set(stats_df[stats_id].astype(str))
        orig_ids = set(orig_gdf[original_id].astype(str))

        # Check for matches
        matching_ids = stats_ids.intersection(orig_ids)
        # NOTE(review): missing_ids is computed but never reported anywhere —
        # presumably intended for diagnostics; confirm before removing.
        missing_ids = stats_ids - orig_ids

        match_percent = (len(matching_ids) / len(stats_ids)) * 100 if stats_ids else 0
        match_msg = f"ID match: {len(matching_ids)}/{len(stats_ids)} ({match_percent:.1f}%)"
        print(match_msg)

        # Abort early when fewer than half the statistic IDs match — a low
        # rate almost always means the wrong ID columns were selected.
        if match_percent < 50:
            warning_msg = f"Warning: Low ID match rate ({match_percent:.1f}%). Check ID column selections."
            if progress is not None:
                progress(1.0, desc=warning_msg)
            return True, warning_msg, None, None

        # Calculate areal intersection
        if progress is not None:
            progress(0.25, desc="Preparing to calculate areal intersection...")
        print("Calculating areal intersection...")
        overlap_df = calculate_areal_intersection(orig_gdf, original_id, new_gdf, new_id, progress)

        if overlap_df.empty:
            error_msg = "Error: Could not calculate area overlap between geographies. Check that they cover the same region."
            if progress is not None:
                progress(1.0, desc=error_msg)
            return True, error_msg, None, None

        # Generate weights matrix
        if progress is not None:
            progress(0.75, desc="Generating weights matrix...")
        print("Generating weights matrix...")
        weights_matrix = generate_weights_matrix(overlap_df, progress)

        if weights_matrix.empty:
            error_msg = "Error: Could not generate weights matrix."
            if progress is not None:
                progress(1.0, desc=error_msg)
            return True, error_msg, None, None

        # Save weights matrix to a CSV file (delete=False so the path
        # survives for the download buttons).
        if progress is not None:
            progress(0.8, desc="Saving weights matrix...")

        temp_weights_file = tempfile.NamedTemporaryFile(delete=False, suffix='_weights.csv')
        weights_path = temp_weights_file.name
        temp_weights_file.close()

        weights_matrix.to_csv(weights_path)
        print(f"Saved weights matrix to {weights_path}")

        # Also save the full intersection data which includes more detailed overlap information
        temp_overlap_file = tempfile.NamedTemporaryFile(delete=False, suffix='_overlap_details.csv')
        overlap_path = temp_overlap_file.name
        temp_overlap_file.close()

        overlap_df.to_csv(overlap_path, index=False)
        print(f"Saved detailed overlap data to {overlap_path}")

        # Create output dataframe with new geography IDs
        if progress is not None:
            progress(0.85, desc="Creating output dataframe...")
        print("Creating output dataframe...")
        output_df = pd.DataFrame({new_id: new_gdf[new_id]})

        # Translate each selected statistic
        total_stats = len(stats_cols)
        for i, stat_col in enumerate(stats_cols):
            stat_msg = f"Translating statistic: {stat_col} ({i+1}/{total_stats})"
            if progress is not None:
                # Per-statistic progress occupies the 0.85–0.95 band.
                progress(0.85 + (0.1 * (i / total_stats)), desc=stat_msg)
            print(stat_msg)

            # Skip (with a warning) statistics missing from the CSV.
            if stat_col not in stats_df.columns:
                print(f"Warning: Statistic column '{stat_col}' not found in statistics file.")
                continue

            # Merge stats with original geography
            merged_df = pd.merge(
                orig_gdf[[original_id]],
                stats_df[[stats_id, stat_col]],
                left_on=original_id,
                right_on=stats_id,
                how='left'
            )

            # Create a series with original IDs as the index and statistic values
            stat_series = merged_df.set_index(original_id)[stat_col]

            # Apply weights to translate statistics to new geography.
            # NOTE(review): this is a weighted SUM (not average), which is the
            # correct areal interpolation for extensive, count-like variables
            # (population, totals). Intensive variables (rates, densities)
            # would need division by total_weight — confirm intended usage.
            new_stat = {}

            for new_area_id in weights_matrix.columns:
                # Get weights for this new area
                area_weights = weights_matrix[new_area_id]

                # Calculate weighted sum
                weighted_sum = 0
                total_weight = 0

                for orig_area_id, weight in area_weights.items():
                    # Skip originals with no statistic value (NaN after merge).
                    if orig_area_id in stat_series and not pd.isna(stat_series[orig_area_id]):
                        weighted_sum += stat_series[orig_area_id] * (weight / 100)
                        total_weight += weight / 100

                # NaN marks new areas that received no contribution at all.
                if total_weight > 0:
                    new_stat[new_area_id] = weighted_sum
                else:
                    new_stat[new_area_id] = np.nan

            # Add to output dataframe
            output_df[stat_col] = output_df[new_id].map(new_stat)

        # Save translated statistics to a CSV
        if progress is not None:
            progress(0.95, desc="Saving translated statistics...")

        temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix='_translated_stats.csv')
        output_path = temp_output_file.name
        temp_output_file.close()

        output_df.to_csv(output_path, index=False)
        print(f"Saved translated statistics to {output_path}")

        # Create a ZIP file with all outputs
        zip_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
        zip_path = zip_file.name
        zip_file.close()

        with zipfile.ZipFile(zip_path, 'w') as zipf:
            zipf.write(output_path, arcname=f"translated_statistics.csv")
            zipf.write(weights_path, arcname=f"weights_matrix.csv")
            zipf.write(overlap_path, arcname=f"overlap_details.csv")

        print(f"Created ZIP archive at {zip_path}")

        # Create summary text shown in the results panel
        summary = f"""
        Translation complete!
        - Translated {len(stats_cols)} statistics from {len(orig_gdf)} original areas to {len(new_gdf)} new areas.
        - ID match rate: {match_percent:.1f}%
        - Found {len(overlap_df)} geographic intersections between areas.

        The download contains:
        - translated_statistics.csv: The statistics mapped to the new geography
        - weights_matrix.csv: The weights matrix for future translations
        - overlap_details.csv: Detailed area overlap information
        """

        if progress is not None:
            progress(1.0, desc="Translation complete!")

        return True, summary, zip_path, weights_path

    except Exception as e:
        print(f"Error in translation: {str(e)}")
        import traceback
        traceback.print_exc()
        if progress is not None:
            progress(1.0, desc=f"Error during translation: {str(e)}")
        return True, f"Error during translation: {str(e)}", None, None
679
+
680
+ def create_translator_interface():
681
+ with gr.Blocks() as translator_interface:
682
+ # Header
683
+ gr.Markdown("## 🗺️ GeoBoundary Translator&nbsp;-&nbsp; Translate Statistics into Different Geographies")
684
+
685
+ # Main content in three columns
686
+ with gr.Row():
687
+ # First column - Original Geography
688
+ with gr.Column(variant="panel", scale=1, min_width=300, elem_id="original-column"):
689
+ gr.Markdown("## Original Geography")
690
+ gr.Markdown("*Supported formats: GeoJSON, GeoPackage (.geojson, .json, .gpkg)*")
691
+ original_file = gr.File(
692
+ label="Upload Geographic File",
693
+ file_types=[".geojson", ".json", ".gpkg"]
694
+ )
695
+ original_id = gr.Dropdown(label="Select Unique ID Column", choices=[])
696
+ original_map = gr.Image(label="Map View", type="filepath")
697
+
698
+ # Second column - New Geography
699
+ with gr.Column(variant="panel", scale=1, min_width=300, elem_id="new-column"):
700
+ gr.Markdown("## New Geography")
701
+ gr.Markdown("*Supported formats: GeoJSON, GeoPackage (.geojson, .json, .gpkg)*")
702
+ new_file = gr.File(
703
+ label="Upload Geographic File",
704
+ file_types=[".geojson", ".json", ".gpkg"]
705
+ )
706
+ new_id = gr.Dropdown(label="Select Unique ID Column", choices=[])
707
+ new_map = gr.Image(label="Map View", type="filepath")
708
+
709
+ # Third column - Statistics and Translation
710
+ with gr.Column(variant="panel", scale=1, min_width=300, elem_id="stats-column"):
711
+ gr.Markdown("## Statistics & Translation")
712
+ stats_file = gr.File(label="Upload CSV File with Statistics (optional for weights only)", file_types=[".csv"])
713
+ stats_id = gr.Dropdown(label="Select Unique ID Column", choices=[])
714
+
715
+ # Add component for selecting statistics columns
716
+ stats_cols = gr.CheckboxGroup(label="Select Statistics Columns to Transfer", choices=[], visible=False)
717
+
718
+ # Translation controls
719
+ gr.Markdown("### Translation Controls")
720
+ with gr.Row():
721
+ translate_btn = gr.Button("Translate Statistics", variant="primary", interactive=False)
722
+ calc_weights_btn = gr.Button("Calculate Weights Only", variant="secondary", interactive=False)
723
+
724
+ # Processing indicator - just use one status text
725
+ status_text = gr.Textbox(label="Status", interactive=False, value="Ready")
726
+
727
+ # Placeholder for results
728
+ with gr.Accordion("Results", open=False, visible=False) as results_accordion:
729
+ results_summary = gr.Textbox(label="Summary", lines=5)
730
+
731
+ with gr.Row():
732
+ download_all_btn = gr.Button("Download All Files (ZIP)")
733
+ download_weights_btn = gr.Button("Download Weights Matrix")
734
+
735
+ download_output = gr.File(label="Download", visible=False)
736
+
737
+ # Connect components with their update functions
738
+ def update_orig_dropdown_choices(file_obj, progress=gr.Progress()):
739
+ if file_obj is None:
740
+ return gr.Dropdown(choices=[], value=None), None
741
+ columns, gdf = extract_columns_from_geo_file(file_obj, progress)
742
+ likely_id_cols = [col for col in columns if any(id_term in col.lower() for id_term in ['id', 'code', 'key'])]
743
+ default_value = likely_id_cols[0] if likely_id_cols else None
744
+ return gr.Dropdown(choices=columns, value=default_value), create_map_visualization(gdf, "Original Geography", progress)
745
+
746
def update_new_dropdown_choices(file_obj, progress=gr.Progress()):
    """Refresh the new-geography ID dropdown and preview map after an upload."""
    if file_obj is None:
        # No file yet: empty the dropdown and hide the map.
        return gr.Dropdown(choices=[], value=None), None

    columns, gdf = extract_columns_from_geo_file(file_obj, progress)

    # Guess a sensible default: first column that looks like an identifier.
    id_like = [c for c in columns if any(term in c.lower() for term in ('id', 'code', 'key'))]
    preselected = id_like[0] if id_like else None

    map_plot = create_map_visualization(gdf, "New Geography", progress)
    return gr.Dropdown(choices=columns, value=preselected), map_plot
753
+
754
def update_stats_dropdown_choices(file_obj, progress=gr.Progress()):
    """Populate the statistics ID dropdown and the transferable-columns checklist from a CSV."""
    if file_obj is None:
        # No CSV: reset both controls and keep the checklist hidden.
        return gr.Dropdown(choices=[], value=None), gr.CheckboxGroup(choices=[], value=[], visible=False)

    columns, df = extract_columns_from_csv(file_obj, progress)

    # Columns that look like identifiers are candidates for the join key,
    # and are excluded from the transferable statistics.
    id_like = [c for c in columns if any(term in c.lower() for term in ('id', 'code', 'key'))]
    preselected_id = id_like[0] if id_like else None

    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    stat_cols = [c for c in numeric_cols if c not in id_like]

    # Pre-tick at most the first five statistics to keep the UI manageable.
    default_selection = stat_cols[:5]

    return (
        gr.Dropdown(choices=columns, value=preselected_id),
        gr.CheckboxGroup(choices=stat_cols, value=default_selection, visible=True),
    )
764
+
765
+ original_file.change(
766
+ fn=update_orig_dropdown_choices,
767
+ inputs=original_file,
768
+ outputs=[original_id, original_map]
769
+ )
770
+
771
+ new_file.change(
772
+ fn=update_new_dropdown_choices,
773
+ inputs=new_file,
774
+ outputs=[new_id, new_map]
775
+ )
776
+
777
+ stats_file.change(
778
+ fn=update_stats_dropdown_choices,
779
+ inputs=stats_file,
780
+ outputs=[stats_id, stats_cols]
781
+ )
782
+
783
+ # Function to check fields and update button status
784
def update_button_status(original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols):
    """Enable or disable the action buttons according to which inputs are complete."""
    translation_ready, weights_only_ready, message = check_fields(
        original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols
    )

    translate_state = gr.Button(interactive=translation_ready)
    weights_state = gr.Button(interactive=weights_only_ready)

    # message goes straight into the status textbox.
    return translate_state, weights_state, message
793
+
794
+ # Connect all inputs to update button status
795
+ for component in [original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols]:
796
+ component.change(
797
+ fn=update_button_status,
798
+ inputs=[original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols],
799
+ outputs=[translate_btn, calc_weights_btn, status_text]
800
+ )
801
+
802
+ # Handlers for translation and weights calculation
803
def translate_statistics_handler(original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols, progress=gr.Progress()):
    """Run the full statistics translation and map its results onto the UI components."""
    # Delegate the heavy lifting; first element of the result is unused here.
    _, summary, zip_path, weights_path = translate_statistics(
        original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols,
        progress=progress
    )

    # NOTE(review): the last call-site outputs are inline gr.State() objects —
    # verify the zip/weights paths actually reach the download handlers.
    return (
        gr.Accordion(visible=True, open=True),        # results_accordion
        summary,                                      # results_summary
        gr.Button(visible=zip_path is not None),      # download_all_btn
        gr.Button(visible=weights_path is not None),  # download_weights_btn
        zip_path,                                     # state: ZIP path
        weights_path,                                 # state: weights path
        "Processing complete",                        # status text
    )
820
+
821
def calculate_weights_handler(original_file, original_id, new_file, new_id, progress=gr.Progress()):
    """Run the weights-only computation and map its results onto the UI components."""
    # Delegate to the core routine; first element of the result is unused here.
    _, summary, zip_path, weights_path = calculate_weights_only(
        original_file, original_id, new_file, new_id,
        progress=progress
    )

    # NOTE(review): the last call-site outputs are inline gr.State() objects —
    # verify the zip/weights paths actually reach the download handlers.
    return (
        gr.Accordion(visible=True, open=True),        # results_accordion
        summary,                                      # results_summary
        gr.Button(visible=zip_path is not None),      # download_all_btn
        gr.Button(visible=weights_path is not None),  # download_weights_btn
        zip_path,                                     # state: ZIP path
        weights_path,                                 # state: weights path
        "Processing complete",                        # status text
    )
838
+
839
+ # Connect buttons with pre-click handlers to show processing status
840
def show_processing():
    """Immediate status message displayed before a long-running job begins."""
    return "Processing started..."
842
+
843
+ translate_btn.click(
844
+ fn=show_processing,
845
+ inputs=[],
846
+ outputs=[status_text],
847
+ queue=False
848
+ ).then(
849
+ fn=translate_statistics_handler,
850
+ inputs=[original_file, original_id, new_file, new_id, stats_file, stats_id, stats_cols],
851
+ outputs=[
852
+ results_accordion,
853
+ results_summary,
854
+ download_all_btn,
855
+ download_weights_btn,
856
+ gr.State(), # For ZIP path
857
+ gr.State(), # For weights path
858
+ status_text
859
+ ]
860
+ )
861
+
862
+ calc_weights_btn.click(
863
+ fn=show_processing,
864
+ inputs=[],
865
+ outputs=[status_text],
866
+ queue=False
867
+ ).then(
868
+ fn=calculate_weights_handler,
869
+ inputs=[original_file, original_id, new_file, new_id],
870
+ outputs=[
871
+ results_accordion,
872
+ results_summary,
873
+ download_all_btn,
874
+ download_weights_btn,
875
+ gr.State(), # For ZIP path
876
+ gr.State(), # For weights path
877
+ status_text
878
+ ]
879
+ )
880
+
881
+ # Handler for download buttons
882
def download_zip(zip_path):
    """Expose the ZIP bundle via the download component when a path is available."""
    if not zip_path:
        # No file to offer: keep the download component hidden.
        return gr.File(visible=False)
    return gr.File(value=zip_path, visible=True)
886
+
887
def download_weights(weights_path):
    """Expose the weights matrix via the download component when a path is available."""
    if not weights_path:
        # Nothing to download: keep the component hidden.
        return gr.File(visible=False)
    return gr.File(value=weights_path, visible=True)
891
+
892
+ # Connect download buttons
893
+ download_all_btn.click(
894
+ fn=download_zip,
895
+ inputs=[gr.State()], # ZIP path
896
+ outputs=[download_output]
897
+ )
898
+
899
+ download_weights_btn.click(
900
+ fn=download_weights,
901
+ inputs=[gr.State()], # Weights path
902
+ outputs=[download_output]
903
+ )
904
+
905
+ # CSS for column styling
906
+ translator_interface.load(
907
+ js="""
908
+ function() {
909
+ // Set background colors for columns
910
+ var originalColumn = document.getElementById('original-column');
911
+ var newColumn = document.getElementById('new-column');
912
+ var statsColumn = document.getElementById('stats-column');
913
+
914
+ if (originalColumn) originalColumn.style.backgroundColor = '#f0f8ff'; // Light blue
915
+ if (newColumn) newColumn.style.backgroundColor = '#fff8f0'; // Light orange
916
+ if (statsColumn) statsColumn.style.backgroundColor = '#f0fff0'; // Light green
917
+ }
918
+ """
919
+ )
920
+
921
+ return translator_interface
922
+
923
+ if __name__ == "__main__":
924
+ # This allows the module to be run directly for testing
925
+ app = create_translator_interface()
926
+ app.launch()
spatial_diffusion.py ADDED
@@ -0,0 +1,1059 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from shapely.geometry import Point, Polygon
5
+ import random
6
+ import datetime
7
+ import gradio as gr
8
+ import tempfile
9
+ import os
10
+ import requests
11
+ import json
12
+ from typing import List, Tuple, Optional, Dict, Any, Union
13
+
14
def fetch_osm_exclusion_zones(bounds: Tuple[float, float, float, float], exclusion_types: List[str]) -> Optional[Any]:
    """
    Fetch exclusion zones from OpenStreetMap using the Overpass API.

    Args:
        bounds: (min_lat, min_lon, max_lat, max_lon) bounding box in WGS84 degrees.
        exclusion_types: Exclusion categories to fetch. Recognised values:
            "Water bodies", "Parks & green spaces", "Industrial areas",
            "Major roads". Unrecognised entries are ignored.

    Returns:
        GeoDataFrame (EPSG:4326) with a 'zone_type' column and polygon
        geometries, or None when nothing was selected/found or the request
        or processing failed (failures are printed, never raised).
    """
    try:
        import geopandas as gpd
        from shapely.geometry import Polygon, LineString

        overpass_url = "https://overpass-api.de/api/interpreter"

        # Shared bounding-box suffix used by every Overpass statement.
        bbox = f"({bounds[0]},{bounds[1]},{bounds[2]},{bounds[3]})"

        # Build the per-type Overpass statements.
        queries = []

        if "Water bodies" in exclusion_types:
            # Water area polygons AND linear waterways.
            queries.extend([
                f'way["natural"="water"]{bbox};',
                f'relation["natural"="water"]{bbox};',
                f'way["landuse"="reservoir"]{bbox};',
                f'way["water"="lake"]{bbox};',
                f'way["water"="pond"]{bbox};',
                # Linear waterways (rivers, streams, canals)
                f'way["waterway"="river"]{bbox};',
                f'way["waterway"="stream"]{bbox};',
                f'way["waterway"="canal"]{bbox};',
            ])

        if "Parks & green spaces" in exclusion_types:
            queries.extend([
                f'way["leisure"="park"]{bbox};',
                f'way["landuse"="forest"]{bbox};',
                f'way["landuse"="grass"]{bbox};',
                f'way["natural"="wood"]{bbox};',
            ])

        if "Industrial areas" in exclusion_types:
            queries.extend([
                f'way["landuse"="industrial"]{bbox};',
                f'way["landuse"="commercial"]{bbox};',
            ])

        if "Major roads" in exclusion_types:
            queries.append(f'way["highway"~"motorway|trunk|primary"]{bbox};')

        if not queries:
            return None

        overpass_query = f"""
        [out:json][timeout:25];
        (
        {chr(10).join(queries)}
        );
        out geom;
        """

        print(f"Fetching OSM data for exclusion zones: {exclusion_types}")

        # FIX: use POST (long queries can exceed GET URL length limits, and
        # Overpass recommends POSTing the query) and set a client-side
        # timeout — the [timeout:25] above only bounds server-side execution,
        # not a stalled connection, so without this the app could hang forever.
        response = requests.post(overpass_url, data={'data': overpass_query}, timeout=60)
        response.raise_for_status()

        data = response.json()

        if not data.get('elements'):
            print("No exclusion zones found in the specified area")
            return None

        # Approximate metres->degrees factor (1 deg latitude ~ 111320 m).
        # The linear-feature buffers below use it regardless of latitude,
        # which is an acceptable approximation for exclusion masks.
        deg_per_m = 1.0 / 111320

        polygons = []
        zone_types = []

        # NOTE(review): 'relation' elements returned by the query are skipped
        # here — only 'way' elements with inline geometry are converted.
        for element in data['elements']:
            try:
                if element['type'] != 'way' or 'geometry' not in element:
                    continue

                tags = element.get('tags', {})
                zone_type = _classify_osm_tags(tags)
                if zone_type is None:
                    continue

                coords = [(node['lon'], node['lat']) for node in element['geometry']]

                if 'waterway' in tags or 'highway' in tags:
                    # Linear features: buffer the centreline into a polygon.
                    if len(coords) < 2:
                        continue
                    try:
                        if 'waterway' in tags:
                            # Rivers ~50 m, canals ~30 m, other waterways ~20 m.
                            width_m = {'river': 50, 'canal': 30}.get(tags.get('waterway'), 20)
                        else:
                            width_m = 25  # major roads ~25 m
                        polygon = LineString(coords).buffer(width_m * deg_per_m)
                        if polygon.is_valid and polygon.area > 0:
                            polygons.append(polygon)
                            zone_types.append(zone_type)
                    except Exception as e:
                        print(f"Error buffering linear feature: {str(e)}")
                        continue
                else:
                    # Area features: close the ring if necessary, then build it.
                    if len(coords) > 2:
                        if coords[0] != coords[-1]:
                            coords.append(coords[0])
                        if len(coords) >= 4:  # a valid closed ring needs >= 4 points
                            try:
                                polygon = Polygon(coords)
                                if polygon.is_valid and polygon.area > 0:
                                    polygons.append(polygon)
                                    zone_types.append(zone_type)
                            except Exception as e:
                                print(f"Error creating polygon: {str(e)}")
                                continue

            except Exception as e:
                print(f"Error processing OSM element: {str(e)}")
                continue

        if not polygons:
            print("No valid polygons found in OSM data")
            return None

        gdf = gpd.GeoDataFrame(
            {'zone_type': zone_types},
            geometry=polygons,
            crs='EPSG:4326'
        )

        print(f"Successfully fetched {len(gdf)} exclusion zones from OpenStreetMap")
        print(f"Zone types found: {gdf['zone_type'].value_counts().to_dict()}")
        return gdf

    except ImportError:
        print("GeoPandas not available for OSM processing")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error fetching data from OpenStreetMap: {str(e)}")
        return None
    except Exception as e:
        print(f"Error processing OpenStreetMap data: {str(e)}")
        return None


def _classify_osm_tags(tags: Dict[str, Any]) -> Optional[str]:
    """Map an OSM tag dict to one of this tool's exclusion-zone labels (or None)."""
    if tags.get('natural') == 'water':
        return 'Water'
    if tags.get('landuse') == 'reservoir':
        return 'Water'
    if 'water' in tags:
        return 'Water'
    if tags.get('waterway') in ('river', 'stream', 'canal'):
        return 'Water'
    if tags.get('leisure') == 'park':
        return 'Park'
    if tags.get('landuse') in ('forest', 'grass'):
        return 'Green space'
    if tags.get('natural') == 'wood':
        return 'Forest'
    if tags.get('landuse') in ('industrial', 'commercial'):
        return 'Industrial/Commercial'
    if 'highway' in tags:
        return 'Major road'
    return None
198
+
199
def calculate_bounds_from_points(input_df: pd.DataFrame, buffer_km: float = 2.0) -> Tuple[float, float, float, float]:
    """
    Return a (min_lat, min_lon, max_lat, max_lon) bounding box around the
    points in *input_df*, padded on every side by *buffer_km* kilometres.
    """
    # Rough conversion: one degree of latitude is approximately 111 km.
    pad = buffer_km / 111.0

    lat_lo, lat_hi = input_df['lat'].min(), input_df['lat'].max()
    lon_lo, lon_hi = input_df['lon'].min(), input_df['lon'].max()

    return (lat_lo - pad, lon_lo - pad, lat_hi + pad, lon_hi + pad)
216
+
217
class SpatialDiffuser:
    """
    Performs spatial diffusion: takes points with counts and scatters synthetic
    points around them according to a chosen distribution within a radius,
    optionally avoiding exclusion zones and assigning random timestamps.
    """

    def __init__(self):
        # Registry mapping distribution names to their generator methods.
        self.distribution_methods = {
            "uniform": self._uniform_distribution,
            "normal": self._normal_distribution,
            "exponential_decay": self._exponential_decay,
            "distance_weighted": self._distance_weighted
        }

    def diffuse_points(self,
                       input_data: pd.DataFrame,
                       distribution_type: str = "uniform",
                       global_radius: Optional[float] = None,
                       time_start: Optional[datetime.datetime] = None,
                       time_end: Optional[datetime.datetime] = None,
                       seed: Optional[int] = None,
                       exclusion_zones_gdf: Optional[Any] = None) -> pd.DataFrame:
        """
        Generate diffused points based on input coordinates and counts.

        Args:
            input_data: DataFrame with columns lat, lon, count and optionally radius.
            distribution_type: One of "uniform", "normal", "exponential_decay",
                "distance_weighted".
            global_radius: Radius in meters for rows without their own radius.
            time_start: Start time for temporal distribution (optional).
            time_end: End time for temporal distribution (optional).
            seed: Random seed for reproducible results.
            exclusion_zones_gdf: GeoDataFrame of polygons to keep points out of.

        Returns:
            DataFrame with columns lat, lon, source_id (plus timestamp when both
            time_start and time_end are given). The column layout is guaranteed
            even when no points are generated.

        Raises:
            ValueError: for an unknown distribution type or a missing radius.
        """
        # Seed both RNGs used by the distribution methods for reproducibility.
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        if distribution_type not in self.distribution_methods:
            raise ValueError(f"Distribution type '{distribution_type}' not supported. Choose from: {list(self.distribution_methods.keys())}")

        temporal = time_start is not None and time_end is not None
        all_points = []

        for idx, row in input_data.iterrows():
            # Per-row radius wins; fall back to the global one.
            # FIX: a partially filled 'radius' column yields NaN via row.get,
            # which previously bypassed global_radius and silently poisoned
            # the generated coordinates — treat NaN the same as missing.
            radius = row.get('radius', None)
            if radius is None or pd.isna(radius):
                radius = global_radius
            if radius is None:
                raise ValueError("Radius must be specified either globally or per point")

            count = int(row['count'])
            if count <= 0:
                continue

            # Generate candidate points, filtered against exclusion zones.
            new_points = self._generate_points_with_exclusions(
                lat=row['lat'],
                lon=row['lon'],
                count=count,
                radius=radius,
                distribution_type=distribution_type,
                exclusion_zones_gdf=exclusion_zones_gdf
            )

            timestamps = (self._generate_timestamps(len(new_points), time_start, time_end)
                          if temporal else None)

            for i, (p_lat, p_lon) in enumerate(new_points):
                record = {'lat': p_lat, 'lon': p_lon, 'source_id': idx}
                if temporal:
                    record['timestamp'] = timestamps[i]
                all_points.append(record)

        if not all_points:
            # FIX: pd.DataFrame([]) has no columns, which breaks downstream
            # code indexing ['lat']/['lon'] — return an empty frame with the
            # expected schema instead.
            cols = ['lat', 'lon', 'source_id'] + (['timestamp'] if temporal else [])
            return pd.DataFrame(columns=cols)

        return pd.DataFrame(all_points)

    def _generate_points_with_exclusions(self, lat: float, lon: float, count: int, radius: float,
                                         distribution_type: str, exclusion_zones_gdf: Optional[Any] = None) -> List[Tuple[float, float]]:
        """
        Generate up to `count` points around (lat, lon), rejecting candidates
        that intersect any exclusion-zone polygon. Falls back to unfiltered
        generation if GeoPandas is unavailable or processing fails.
        """
        generate = self.distribution_methods[distribution_type]

        if exclusion_zones_gdf is None or len(exclusion_zones_gdf) == 0:
            # No exclusion zones: plain generation.
            return generate(lat, lon, count, radius)

        try:
            import geopandas as gpd  # noqa: F401 — triggers the ImportError fallback
            from shapely.geometry import Point

            # Ensure exclusion zones are in WGS84 to match generated coordinates.
            if exclusion_zones_gdf.crs is None:
                exclusion_zones_gdf = exclusion_zones_gdf.set_crs('EPSG:4326')
            elif exclusion_zones_gdf.crs != 'EPSG:4326':
                exclusion_zones_gdf = exclusion_zones_gdf.to_crs('EPSG:4326')

            # Hoist the geometries once instead of calling iterrows() per
            # candidate point (the old inner loop was needlessly slow).
            zone_geoms = list(exclusion_zones_gdf.geometry)

            valid_points = []
            max_attempts = count * 10  # rejection-sampling budget (10x oversampling)
            attempts = 0

            while len(valid_points) < count and attempts < max_attempts:
                batch_size = min(count * 2, max_attempts - attempts)
                for point in generate(lat, lon, batch_size, radius):
                    if len(valid_points) >= count:
                        break
                    point_geom = Point(point[1], point[0])  # Point takes (lon, lat)
                    if not any(point_geom.intersects(g) for g in zone_geoms):
                        valid_points.append(point)
                attempts += batch_size

            if len(valid_points) < count:
                print(f"Warning: Could only generate {len(valid_points)} valid points out of {count} requested for location ({lat}, {lon}). Exclusion zones may be too restrictive.")

            return valid_points

        except ImportError:
            print("GeoPandas not available for exclusion zone processing. Generating points without exclusions.")
            return generate(lat, lon, count, radius)
        except Exception as e:
            print(f"Error processing exclusion zones: {str(e)}. Generating points without exclusions.")
            return generate(lat, lon, count, radius)

    def _uniform_distribution(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """Generate points uniformly distributed within a circle of `radius` meters."""
        points = []

        for _ in range(count):
            angle = random.uniform(0, 2 * np.pi)
            # sqrt of a uniform variate gives an areal-uniform radius
            # (without it, points cluster near the center).
            r = radius * np.sqrt(random.uniform(0, 1))

            # Polar -> Cartesian offsets in meters.
            x = r * np.cos(angle)
            y = r * np.sin(angle)

            # Meters -> approximate degrees (equirectangular approximation;
            # a projection-based approach would be more accurate).
            lat_offset = y / 111320  # 1 deg latitude ~ 111320 m
            lon_offset = x / (111320 * np.cos(np.radians(lat)))  # shrink with latitude

            points.append((lat + lat_offset, lon + lon_offset))

        return points

    def _normal_distribution(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """Generate points normally distributed around the center (denser near it)."""
        points = []

        # 3-sigma rule: ~99.7% of draws land within `radius` before rejection.
        std_dev = radius / 3

        for _ in range(count):
            # Rejection-sample until the offset falls inside the circle.
            while True:
                x = np.random.normal(0, std_dev)
                y = np.random.normal(0, std_dev)
                if np.sqrt(x**2 + y**2) <= radius:
                    break

            # Meters -> approximate degrees.
            lat_offset = y / 111320
            lon_offset = x / (111320 * np.cos(np.radians(lat)))

            points.append((lat + lat_offset, lon + lon_offset))

        return points

    def _exponential_decay(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """Generate points whose density decays exponentially with distance from center."""
        points = []

        # Rate parameter: higher value means a steeper decay.
        rate = 3.0 / radius

        for _ in range(count):
            angle = random.uniform(0, 2 * np.pi)

            # Rejection-sample the exponential distance so it stays in-radius.
            while True:
                r = random.expovariate(rate)
                if r <= radius:
                    break

            x = r * np.cos(angle)
            y = r * np.sin(angle)

            # Meters -> approximate degrees.
            lat_offset = y / 111320
            lon_offset = x / (111320 * np.cos(np.radians(lat)))

            points.append((lat + lat_offset, lon + lon_offset))

        return points

    def _distance_weighted(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """
        Generate points with a custom distance-weighted distribution: more
        points at medium distances than at the center or at the edge.
        """
        points = []

        for _ in range(count):
            angle = random.uniform(0, 2 * np.pi)

            # Beta(2, 2) peaks at 0.5, concentrating r^2 at middle distances.
            r_squared = random.betavariate(2, 2)
            r = np.sqrt(r_squared) * radius

            x = r * np.cos(angle)
            y = r * np.sin(angle)

            # Meters -> approximate degrees.
            lat_offset = y / 111320
            lon_offset = x / (111320 * np.cos(np.radians(lat)))

            points.append((lat + lat_offset, lon + lon_offset))

        return points

    def _generate_timestamps(self, count: int, start_time: datetime.datetime, end_time: datetime.datetime) -> List[datetime.datetime]:
        """Generate `count` timestamps uniformly distributed in [start_time, end_time], sorted ascending."""
        start_ts = start_time.timestamp()
        end_ts = end_time.timestamp()

        # NOTE(review): fromtimestamp() yields naive local-time datetimes —
        # confirm callers expect local time rather than UTC.
        timestamps = [
            datetime.datetime.fromtimestamp(random.uniform(start_ts, end_ts))
            for _ in range(count)
        ]

        timestamps.sort()
        return timestamps
518
+
519
+ def create_visualization(input_df, output_df, show_basemap=False, exclusion_zones_gdf=None):
520
+ """Create visualization of input and diffused points"""
521
+ fig, ax = plt.subplots(figsize=(12, 10))
522
+
523
+ # Set background color
524
+ fig.patch.set_facecolor('white')
525
+ ax.set_facecolor('#f8f9fa')
526
+
527
+ # Define colors for different exclusion zone types
528
+ exclusion_colors = {
529
+ 'Water': '#4FC3F7', # Light blue
530
+ 'Park': '#66BB6A', # Green
531
+ 'Green space': '#81C784', # Light green
532
+ 'Forest': '#4CAF50', # Dark green
533
+ 'Industrial/Commercial': '#90A4AE', # Grey
534
+ 'Major road': '#FFD54F', # Yellow
535
+ 'Other': '#FFAB91' # Light orange
536
+ }
537
+
538
+ # If basemap is requested, convert to Web Mercator and add basemap
539
+ if show_basemap:
540
+ try:
541
+ import contextily as ctx
542
+ import geopandas as gpd
543
+ from shapely.geometry import Point
544
+
545
+ # Create GeoDataFrames for proper projection
546
+ input_gdf = gpd.GeoDataFrame(
547
+ input_df,
548
+ geometry=[Point(lon, lat) for lon, lat in zip(input_df['lon'], input_df['lat'])],
549
+ crs='EPSG:4326'
550
+ )
551
+ output_gdf = gpd.GeoDataFrame(
552
+ output_df,
553
+ geometry=[Point(lon, lat) for lon, lat in zip(output_df['lon'], output_df['lat'])],
554
+ crs='EPSG:4326'
555
+ )
556
+
557
+ # Convert to Web Mercator for basemap compatibility
558
+ input_gdf_merc = input_gdf.to_crs('EPSG:3857')
559
+ output_gdf_merc = output_gdf.to_crs('EPSG:3857')
560
+
561
+ # Plot exclusion zones first (if provided) with color coding
562
+ if exclusion_zones_gdf is not None and len(exclusion_zones_gdf) > 0:
563
+ try:
564
+ exclusion_zones_merc = exclusion_zones_gdf.to_crs('EPSG:3857')
565
+
566
+ # Group by zone type and plot with appropriate colors
567
+ plotted_types = set()
568
+ for zone_type in exclusion_zones_merc['zone_type'].unique():
569
+ zone_subset = exclusion_zones_merc[exclusion_zones_merc['zone_type'] == zone_type]
570
+ color = exclusion_colors.get(zone_type, exclusion_colors['Other'])
571
+
572
+ # Only add label for first occurrence of each type
573
+ label = zone_type if zone_type not in plotted_types else None
574
+ if label:
575
+ plotted_types.add(zone_type)
576
+
577
+ zone_subset.plot(ax=ax, color=color, alpha=0.6, edgecolor='white',
578
+ linewidth=0.5, label=label)
579
+
580
+ except Exception as e:
581
+ print(f"Error plotting exclusion zones: {str(e)}")
582
+
583
+ # Extract coordinates for plotting
584
+ input_x = input_gdf_merc.geometry.x
585
+ input_y = input_gdf_merc.geometry.y
586
+ output_x = output_gdf_merc.geometry.x
587
+ output_y = output_gdf_merc.geometry.y
588
+
589
+ # Plot diffused points first (so they appear behind source points)
590
+ ax.scatter(output_x, output_y,
591
+ alpha=0.7, color='#FF9800', s=12, label=f'Generated Points (n={len(output_df)})',
592
+ edgecolors='white', linewidth=0.3)
593
+
594
+ # Draw radius circles first (so they appear behind everything else)
595
+ for idx, row in input_df.iterrows():
596
+ radius = row.get('radius', None)
597
+
598
+ if radius is not None:
599
+ # Convert center point to Web Mercator
600
+ center_point = gpd.GeoDataFrame(
601
+ [1], geometry=[Point(row['lon'], row['lat'])], crs='EPSG:4326'
602
+ ).to_crs('EPSG:3857')
603
+
604
+ center_x = center_point.geometry.x.iloc[0]
605
+ center_y = center_point.geometry.y.iloc[0]
606
+
607
+ # Draw circle (radius is already in meters, which matches Web Mercator units)
608
+ circle = plt.Circle((center_x, center_y), radius,
609
+ fill=False, color='#9C27B0', linestyle='--',
610
+ alpha=0.5, linewidth=2)
611
+ ax.add_patch(circle)
612
+
613
+ # Plot source points with circles sized by count
614
+ min_size = 100
615
+ max_size = 800
616
+ if len(input_df) > 1:
617
+ size_range = input_df['count'].max() - input_df['count'].min()
618
+ if size_range > 0:
619
+ sizes = min_size + (input_df['count'] - input_df['count'].min()) / size_range * (max_size - min_size)
620
+ else:
621
+ sizes = [min_size] * len(input_df)
622
+ else:
623
+ sizes = [max_size]
624
+
625
+ # Plot source points in purple
626
+ ax.scatter(input_x, input_y,
627
+ s=sizes, c='#9C27B0', alpha=0.9,
628
+ edgecolors='white', linewidth=2,
629
+ label='Source Points (size = count)', zorder=5)
630
+
631
+ # Add count labels next to source points
632
+ for idx, row in input_df.iterrows():
633
+ point_merc = gpd.GeoDataFrame(
634
+ [1], geometry=[Point(row['lon'], row['lat'])], crs='EPSG:4326'
635
+ ).to_crs('EPSG:3857')
636
+
637
+ x_merc = point_merc.geometry.x.iloc[0]
638
+ y_merc = point_merc.geometry.y.iloc[0]
639
+
640
+ ax.annotate(f'{int(row["count"])}',
641
+ (x_merc, y_merc),
642
+ xytext=(8, 8), textcoords='offset points',
643
+ fontsize=10, fontweight='bold', color='white',
644
+ bbox=dict(boxstyle='round,pad=0.3', facecolor='#9C27B0', alpha=0.8),
645
+ zorder=6)
646
+
647
+ # Add basemap
648
+ try:
649
+ ctx.add_basemap(ax, crs='EPSG:3857', source=ctx.providers.CartoDB.Positron, alpha=0.8)
650
+ basemap_added = True
651
+ except Exception as e:
652
+ print(f"Could not add basemap: {str(e)}")
653
+ basemap_added = False
654
+
655
+ # Set axis labels for Web Mercator
656
+ ax.set_xlabel('Easting (Web Mercator)', fontsize=12)
657
+ ax.set_ylabel('Northing (Web Mercator)', fontsize=12)
658
+
659
+ # Use projected coordinates for limits
660
+ x_coords = list(input_x) + list(output_x)
661
+ y_coords = list(input_y) + list(output_y)
662
+
663
+ except ImportError:
664
+ print("Contextily not available for basemap. Falling back to simple plot.")
665
+ show_basemap = False
666
+ except Exception as e:
667
+ print(f"Error creating basemap: {str(e)}. Falling back to simple plot.")
668
+ show_basemap = False
669
+
670
+ # Fallback to simple plot if basemap fails or is not requested
671
+ if not show_basemap:
672
+ # Plot exclusion zones first (if provided) with color coding
673
+ if exclusion_zones_gdf is not None and len(exclusion_zones_gdf) > 0:
674
+ try:
675
+ # Ensure exclusion zones are in WGS84
676
+ if exclusion_zones_gdf.crs != 'EPSG:4326':
677
+ exclusion_zones_gdf = exclusion_zones_gdf.to_crs('EPSG:4326')
678
+
679
+ # Plot zones by type with appropriate colors
680
+ plotted_types = set()
681
+ for idx, zone in exclusion_zones_gdf.iterrows():
682
+ zone_type = zone.get('zone_type', 'Other')
683
+ color = exclusion_colors.get(zone_type, exclusion_colors['Other'])
684
+
685
+ # Only add label for first occurrence of each type
686
+ label = zone_type if zone_type not in plotted_types else None
687
+ if label:
688
+ plotted_types.add(zone_type)
689
+
690
+ if zone.geometry.geom_type == 'Polygon':
691
+ x, y = zone.geometry.exterior.xy
692
+ ax.fill(x, y, color=color, alpha=0.6, edgecolor='white',
693
+ linewidth=0.5, label=label)
694
+ elif zone.geometry.geom_type == 'MultiPolygon':
695
+ for poly in zone.geometry.geoms:
696
+ x, y = poly.exterior.xy
697
+ ax.fill(x, y, color=color, alpha=0.6, edgecolor='white',
698
+ linewidth=0.5, label=label)
699
+ label = None # Only label the first polygon
700
+
701
+ except Exception as e:
702
+ print(f"Error plotting exclusion zones: {str(e)}")
703
+
704
+ # Plot diffused points first (so they appear behind source points) - orange
705
+ ax.scatter(output_df['lon'], output_df['lat'],
706
+ alpha=0.7, color='#FF9800', s=12, label=f'Generated Points (n={len(output_df)})',
707
+ edgecolors='white', linewidth=0.3)
708
+
709
+ # Draw radius circles first (so they appear behind everything else) - purple
710
+ for idx, row in input_df.iterrows():
711
+ radius = row.get('radius', None)
712
+
713
+ if radius is not None:
714
+ # Approximate conversion from meters to degrees
715
+ radius_deg_lat = radius / 111320
716
+ radius_deg_lon = radius / (111320 * np.cos(np.radians(row['lat'])))
717
+
718
+ # Use the average as an approximation
719
+ radius_deg = (radius_deg_lat + radius_deg_lon) / 2
720
+
721
+ # Draw circle in purple
722
+ circle = plt.Circle((row['lon'], row['lat']), radius_deg,
723
+ fill=False, color='#9C27B0', linestyle='--',
724
+ alpha=0.5, linewidth=2)
725
+ ax.add_patch(circle)
726
+
727
+ # Plot source points with circles sized by count - purple
728
+ min_size = 100
729
+ max_size = 800
730
+ if len(input_df) > 1:
731
+ size_range = input_df['count'].max() - input_df['count'].min()
732
+ if size_range > 0:
733
+ sizes = min_size + (input_df['count'] - input_df['count'].min()) / size_range * (max_size - min_size)
734
+ else:
735
+ sizes = [min_size] * len(input_df)
736
+ else:
737
+ sizes = [max_size]
738
+
739
+ # Plot source points in purple
740
+ ax.scatter(input_df['lon'], input_df['lat'],
741
+ s=sizes, c='#9C27B0', alpha=0.9,
742
+ edgecolors='white', linewidth=2,
743
+ label='Source Points (size = count)', zorder=5)
744
+
745
+ # Add count labels next to source points with purple background
746
+ for idx, row in input_df.iterrows():
747
+ ax.annotate(f'{int(row["count"])}',
748
+ (row['lon'], row['lat']),
749
+ xytext=(8, 8), textcoords='offset points',
750
+ fontsize=10, fontweight='bold', color='white',
751
+ bbox=dict(boxstyle='round,pad=0.3', facecolor='#9C27B0', alpha=0.8),
752
+ zorder=6)
753
+
754
+ # Standard coordinate labels
755
+ ax.set_xlabel('Longitude', fontsize=12)
756
+ ax.set_ylabel('Latitude', fontsize=12)
757
+
758
+ # Use original coordinates for limits
759
+ x_coords = list(input_df['lon']) + list(output_df['lon'])
760
+ y_coords = list(input_df['lat']) + list(output_df['lat'])
761
+
762
+ # Common styling
763
+ title = 'Spatial Diffusion Results'
764
+ if show_basemap:
765
+ title += ' (with Basemap)'
766
+ if exclusion_zones_gdf is not None and len(exclusion_zones_gdf) > 0:
767
+ title += ' - Exclusion Zones Applied'
768
+ subtitle = 'Purple source points sized by count, orange generated points, dashed circles show diffusion radius'
769
+
770
+ ax.set_title(f'{title}\n{subtitle}',
771
+ fontsize=14, fontweight='bold', pad=20)
772
+
773
+ # Legend with better positioning
774
+ legend = ax.legend(loc='upper right', bbox_to_anchor=(1, 1),
775
+ frameon=True, fancybox=True, shadow=True)
776
+ legend.get_frame().set_facecolor('white')
777
+ legend.get_frame().set_alpha(0.9)
778
+
779
+ # Add grid (lighter for basemap)
780
+ grid_alpha = 0.2 if show_basemap else 0.3
781
+ ax.grid(True, alpha=grid_alpha, linestyle='-', linewidth=0.5)
782
+
783
+ # Make equal aspect ratio
784
+ ax.set_aspect('equal', 'box')
785
+
786
+ # Add some padding around the data
787
+ x_margin = (max(x_coords) - min(x_coords)) * 0.1
788
+ y_margin = (max(y_coords) - min(y_coords)) * 0.1
789
+
790
+ if x_margin == 0: # Handle case where all points have same x-coordinate
791
+ x_margin = 1000 if show_basemap else 0.01
792
+ if y_margin == 0: # Handle case where all points have same y-coordinate
793
+ y_margin = 1000 if show_basemap else 0.01
794
+
795
+ ax.set_xlim(min(x_coords) - x_margin, max(x_coords) + x_margin)
796
+ ax.set_ylim(min(y_coords) - y_margin, max(y_coords) + y_margin)
797
+
798
+ # Tight layout
799
+ plt.tight_layout()
800
+
801
+ return fig
802
+
803
def process_csv(file_obj, distribution_type, global_radius, show_basemap, auto_exclusions, exclusion_file, include_time, time_start, time_end, seed):
    """Process input CSV and generate diffused points.

    Parameters
    ----------
    file_obj : uploaded file object with a ``.name`` path
        CSV that must contain ``lat``, ``lon`` and ``count`` columns and
        may contain a per-row ``radius`` column (meters).
    distribution_type : str
        One of the diffusion distributions understood by ``SpatialDiffuser``
        (e.g. ``"uniform"``, ``"normal"``).
    global_radius : str
        Optional global diffusion radius in meters; if blank, the CSV must
        supply a ``radius`` column instead.
    show_basemap : bool
        Whether the visualization should include a contextily basemap.
    auto_exclusions : list[str]
        OpenStreetMap feature categories to fetch automatically as
        exclusion zones (ignored when ``exclusion_file`` is given).
    exclusion_file : uploaded vector file or None
        Custom exclusion zones (GeoJSON/GPKG/SHP); takes priority over
        ``auto_exclusions``.
    include_time : bool
        Whether to distribute generated points over a time window.
    time_start, time_end : str
        Window bounds in ``YYYY-MM-DD HH:MM:SS`` format (required when
        ``include_time`` is True).
    seed : str
        Optional integer random seed for reproducible output.

    Returns
    -------
    tuple
        ``(matplotlib figure, path to generated CSV)`` on success, or
        ``(None, error message string)`` on any validation/processing error.
    """
    try:
        # Read input CSV
        df = pd.read_csv(file_obj.name)

        # Validate required columns
        required_cols = ['lat', 'lon', 'count']
        if not all(col in df.columns for col in required_cols):
            return None, f"Error: CSV must contain columns: {', '.join(required_cols)}"

        # Convert global_radius to float if provided
        if global_radius and global_radius.strip():
            try:
                global_radius = float(global_radius)
            except ValueError:
                return None, "Error: Global radius must be a number"
        else:
            global_radius = None
            # If global radius not provided, check for radius column
            if 'radius' not in df.columns:
                return None, "Error: Either provide a global radius or include a 'radius' column in the CSV"

        # Convert seed to int if provided
        if seed and seed.strip():
            try:
                seed = int(seed)
            except ValueError:
                return None, "Error: Seed must be an integer"
        else:
            seed = None

        # Process exclusion zones
        exclusion_zones_gdf = None

        # First, try manual file upload (takes priority)
        if exclusion_file is not None:
            try:
                import geopandas as gpd

                # All supported vector formats are read by the same
                # gpd.read_file call; the extension check only rejects
                # formats we do not support.
                file_extension = os.path.splitext(exclusion_file.name)[1].lower()
                if file_extension in ('.geojson', '.json', '.gpkg', '.shp'):
                    exclusion_zones_gdf = gpd.read_file(exclusion_file.name)
                else:
                    return None, f"Error: Unsupported exclusion zone file format: {file_extension}"

                # Ensure CRS is set
                if exclusion_zones_gdf.crs is None:
                    exclusion_zones_gdf = exclusion_zones_gdf.set_crs('EPSG:4326')

                print(f"Loaded {len(exclusion_zones_gdf)} custom exclusion zones from {exclusion_file.name}")

            except ImportError:
                return None, "Error: GeoPandas required for exclusion zones processing"
            except Exception as e:
                return None, f"Error reading exclusion zones file: {str(e)}"

        # If no manual file, try automatic exclusions from OpenStreetMap
        elif auto_exclusions and len(auto_exclusions) > 0:
            try:
                # Calculate bounds around input points
                bounds = calculate_bounds_from_points(df)
                print(f"Fetching automatic exclusions for bounds: {bounds}")

                # Fetch OSM data
                exclusion_zones_gdf = fetch_osm_exclusion_zones(bounds, auto_exclusions)

                if exclusion_zones_gdf is not None:
                    print(f"Fetched {len(exclusion_zones_gdf)} exclusion zones from OpenStreetMap")
                else:
                    print("No exclusion zones found in OpenStreetMap for this area")

            except Exception as e:
                print(f"Warning: Could not fetch automatic exclusions: {str(e)}")
                # Continue without exclusions rather than failing completely
                exclusion_zones_gdf = None

        # Process time if requested
        if include_time:
            if not time_start or not time_end:
                return None, "Error: If time distribution is enabled, both start and end times must be provided"
            try:
                time_start_dt = datetime.datetime.strptime(time_start, "%Y-%m-%d %H:%M:%S")
                time_end_dt = datetime.datetime.strptime(time_end, "%Y-%m-%d %H:%M:%S")
                if time_start_dt >= time_end_dt:
                    return None, "Error: End time must be after start time"
            except ValueError:
                return None, "Error: Invalid time format. Use YYYY-MM-DD HH:MM:SS"
        else:
            time_start_dt = None
            time_end_dt = None

        # Create diffuser and generate diffused points
        diffuser = SpatialDiffuser()
        result_df = diffuser.diffuse_points(
            input_data=df,
            distribution_type=distribution_type,
            global_radius=global_radius,
            time_start=time_start_dt,
            time_end=time_end_dt,
            seed=seed,
            exclusion_zones_gdf=exclusion_zones_gdf
        )

        # Write the result into a fresh per-request temporary directory
        # instead of a fixed path in the CWD, so concurrent users of the
        # Gradio app cannot clobber each other's downloads; the friendly
        # filename is preserved for the browser download.
        import tempfile
        temp_file = os.path.join(tempfile.mkdtemp(), "diffused_points.csv")
        result_df.to_csv(temp_file, index=False)

        # Create visualization with basemap and exclusion zones
        fig = create_visualization(df, result_df, show_basemap, exclusion_zones_gdf)

        return fig, temp_file

    except Exception as e:
        return None, f"Error: {str(e)}"
924
+
925
def create_diffusion_interface():
    """Create Gradio interface for the spatial diffusion tool.

    Builds a two-column layout: input controls (CSV upload, distribution,
    temporal and exclusion-zone options) on the left, and the result plot
    plus download link on the right.  Returns the ``gr.Blocks`` container
    so it can be embedded as a tab in a larger app or launched standalone.
    """

    with gr.Blocks() as diffusion_interface:
        gr.Markdown("## 🗺️ Spatial Diffusion Tool")

        with gr.Row():
            with gr.Column(scale=1):
                # Move description into the left column for better space usage
                gr.Markdown("""
                ### About This Tool
                Transform aggregated geographic points with counts into individual points using spatial diffusion methods.

                **Input CSV Format:**
                - `lat`: Latitude of source point
                - `lon`: Longitude of source point
                - `count`: Number of points to generate
                - `radius`: (Optional) Diffusion radius in meters

                **Distribution Types:**
                - **Uniform**: Equal probability throughout circle
                - **Normal**: Higher density near center
                - **Exponential Decay**: Density decreases from center
                - **Distance-Weighted**: More points at medium distances
                """)

                # Input controls
                input_file = gr.File(label="Input CSV File", file_types=[".csv"])

                # Distribution options grouped together
                gr.Markdown("### 🎯 Distribution Options")
                with gr.Row():
                    distribution = gr.Dropdown(
                        choices=["uniform", "normal", "exponential_decay", "distance_weighted"],
                        value="uniform",
                        label="Distribution Type",
                        scale=2
                    )
                    seed = gr.Textbox(
                        label="Random Seed (optional)",
                        placeholder="e.g. 42",
                        scale=1
                    )

                global_radius = gr.Textbox(
                    label="Global Radius (meters)",
                    placeholder="Only if radius column not in CSV"
                )

                # Temporal controls in distribution section
                with gr.Accordion("⏰ Temporal Distribution (Optional)", open=False):
                    include_time = gr.Checkbox(label="Enable Temporal Distribution", value=False)
                    # Start hidden so the initial UI matches the unchecked
                    # checkbox; the include_time.change handler below toggles
                    # visibility.
                    with gr.Group(visible=False) as time_group:
                        time_start = gr.Textbox(
                            label="Start Time",
                            placeholder="YYYY-MM-DD HH:MM:SS"
                        )
                        time_end = gr.Textbox(
                            label="End Time",
                            placeholder="YYYY-MM-DD HH:MM:SS"
                        )

                # Map and exclusion options grouped together
                gr.Markdown("### 🗺️ Map & Exclusion Options")
                show_basemap = gr.Checkbox(
                    label="Show underlying map (requires internet)",
                    value=False
                )
                gr.Markdown("*Adds geographic context with street/satellite imagery*")

                # Automatic exclusion zones - no default selection
                auto_exclusions = gr.CheckboxGroup(
                    label="Auto-exclude from OpenStreetMap:",
                    choices=["Water bodies", "Parks & green spaces", "Industrial areas", "Major roads"],
                    value=[]  # No default selections
                )

                # Advanced manual exclusion zones
                with gr.Accordion("🔧 Advanced: Custom Exclusion Zones", open=False):
                    exclusion_file = gr.File(
                        label="Upload custom shapefile (optional)",
                        file_types=[".geojson", ".json", ".gpkg", ".shp"]
                    )
                    gr.Markdown("*Overrides automatic exclusions if provided*")

                process_btn = gr.Button(
                    "🎯 Generate Diffused Points",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=2):
                # Give more space to visualization
                plot_output = gr.Plot(
                    label="📍 Spatial Diffusion Visualization",
                    show_label=True
                )

                with gr.Row():
                    with gr.Column(scale=2):
                        file_output = gr.File(label="📥 Download Generated Points")
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """
                            **Legend:**
                            🟣 Source points (sized by count)
                            🟠 Generated points
                            ⭕ Diffusion radius
                            🟦 Water bodies
                            🟢 Parks & green spaces
                            ⬜ Industrial areas
                            🟡 Major roads
                            """
                        )

        # Set up event handlers
        process_btn.click(
            fn=process_csv,
            inputs=[input_file, distribution, global_radius, show_basemap, auto_exclusions, exclusion_file, include_time, time_start, time_end, seed],
            outputs=[plot_output, file_output]
        )

        # Show/hide time inputs based on checkbox
        include_time.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[include_time],
            outputs=[time_group]
        )

    return diffusion_interface
1055
+
1056
+ if __name__ == "__main__":
1057
+ # This allows the module to be run directly for testing
1058
+ app = create_diffusion_interface()
1059
+ app.launch()