File size: 48,449 Bytes
617c74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from shapely.geometry import Point, Polygon
import random
import datetime
import gradio as gr
import tempfile
import os
import requests
import json
from typing import List, Tuple, Optional, Dict, Any, Union

# Approximate metres per degree of latitude; used to express buffer distances in degrees.
_METRES_PER_DEGREE = 111320.0

# Buffer distance (metres, applied on each side of the line) that turns linear OSM
# features into exclusion areas.
_LINEAR_WATERWAY_BUFFER_M = {'river': 50.0, 'canal': 30.0, 'stream': 20.0}
_MAJOR_ROAD_BUFFER_M = 25.0

# Overpass selectors (bounding box appended later) requested for each exclusion category.
_OSM_SELECTORS = {
    "Water bodies": [
        # Water area polygons
        'way["natural"="water"]',
        'relation["natural"="water"]',
        'way["landuse"="reservoir"]',
        'way["water"="lake"]',
        'way["water"="pond"]',
        # Linear waterways (rivers, streams, canals)
        'way["waterway"="river"]',
        'way["waterway"="stream"]',
        'way["waterway"="canal"]',
    ],
    "Parks & green spaces": [
        'way["leisure"="park"]',
        'way["landuse"="forest"]',
        'way["landuse"="grass"]',
        'way["natural"="wood"]',
    ],
    "Industrial areas": [
        'way["landuse"="industrial"]',
        'way["landuse"="commercial"]',
    ],
    "Major roads": [
        'way["highway"~"motorway|trunk|primary"]',
    ],
}


def _classify_osm_tags(tags: Dict[str, Any]) -> Tuple[Optional[str], Optional[float]]:
    """Map OSM tags to ``(zone_type, buffer_metres)``.

    ``buffer_metres`` is None for area features (used as polygons as-is) and a
    distance for linear features (buffered into polygons). Returns ``(None, None)``
    when the tags match no supported exclusion type.
    """
    natural = tags.get('natural')
    landuse = tags.get('landuse')
    waterway = tags.get('waterway')

    if natural == 'water' or landuse == 'reservoir' or 'water' in tags:
        return 'Water', None
    if waterway in _LINEAR_WATERWAY_BUFFER_M:
        return 'Water', _LINEAR_WATERWAY_BUFFER_M[waterway]
    if tags.get('leisure') == 'park':
        return 'Park', None
    if landuse in ('forest', 'grass'):
        return 'Green space', None
    if natural == 'wood':
        return 'Forest', None
    if landuse in ('industrial', 'commercial'):
        return 'Industrial/Commercial', None
    if 'highway' in tags:
        return 'Major road', _MAJOR_ROAD_BUFFER_M
    return None, None


def _way_to_polygon(coords: List[Tuple[float, float]], buffer_m: Optional[float]) -> Optional[Any]:
    """Build a shapely polygon from a way's (lon, lat) coordinates, or None if too short."""
    from shapely.geometry import LineString, Polygon

    if buffer_m is not None:
        # Linear feature: buffer in degree space (approximation, isotropic in degrees).
        if len(coords) < 2:
            return None
        return LineString(coords).buffer(buffer_m / _METRES_PER_DEGREE)

    if len(coords) < 3:
        return None
    if coords[0] != coords[-1]:
        coords = coords + [coords[0]]  # close the ring
    if len(coords) < 4:  # a valid ring needs at least 4 points
        return None
    return Polygon(coords)


def _relation_to_polygon(element: Dict[str, Any]) -> Optional[Any]:
    """Assemble a multipolygon relation (``out geom`` format) into one shapely geometry.

    Outer member ways are stitched into rings with ``polygonize`` (outer rings are
    frequently split across several ways); inner rings are subtracted as holes.
    """
    from shapely.geometry import LineString
    from shapely.ops import polygonize, unary_union

    outer_lines, inner_lines = [], []
    for member in element.get('members', []):
        if member.get('type') != 'way':
            continue
        geometry = member.get('geometry') or []
        if any(pt is None for pt in geometry):
            continue  # partially clipped member; cannot be stitched reliably
        coords = [(pt['lon'], pt['lat']) for pt in geometry]
        if len(coords) < 2:
            continue
        target = inner_lines if member.get('role') == 'inner' else outer_lines
        target.append(LineString(coords))

    if not outer_lines:
        return None
    area = unary_union(list(polygonize(outer_lines)))
    if inner_lines:
        area = area.difference(unary_union(list(polygonize(inner_lines))))
    return area


def fetch_osm_exclusion_zones(bounds: Tuple[float, float, float, float], exclusion_types: List[str],
                              timeout: float = 60.0) -> Optional[Any]:
    """
    Fetch exclusion zones from OpenStreetMap using the Overpass API.

    Args:
        bounds: (min_lat, min_lon, max_lat, max_lon) bounding box, which is the
            (south, west, north, east) order Overpass expects.
        exclusion_types: Exclusion categories to fetch; recognised values are the
            keys of ``_OSM_SELECTORS`` ("Water bodies", "Parks & green spaces",
            "Industrial areas", "Major roads").
        timeout: Client-side HTTP timeout in seconds, so a stalled Overpass server
            cannot block the caller indefinitely.

    Returns:
        GeoDataFrame (EPSG:4326) with a ``zone_type`` column and polygon geometries,
        or None when nothing was requested, nothing was found, or any step failed.
        Failures are printed rather than raised (best-effort fetch).
    """
    try:
        import geopandas as gpd

        overpass_url = "https://overpass-api.de/api/interpreter"

        # Build one Overpass statement per selector of every requested category.
        bbox = f"({bounds[0]},{bounds[1]},{bounds[2]},{bounds[3]})"
        queries = [
            f"{selector}{bbox};"
            for category, selectors in _OSM_SELECTORS.items()
            if category in exclusion_types
            for selector in selectors
        ]
        if not queries:
            return None

        overpass_query = "[out:json][timeout:25];\n(\n" + "\n".join(queries) + "\n);\nout geom;\n"

        print(f"Fetching OSM data for exclusion zones: {exclusion_types}")

        # POST avoids URL-length limits for long union queries.
        response = requests.post(overpass_url, data={'data': overpass_query}, timeout=timeout)
        response.raise_for_status()
        data = response.json()

        elements = data.get('elements') or []
        if not elements:
            print("No exclusion zones found in the specified area")
            return None

        polygons = []
        zone_types = []
        for element in elements:
            try:
                tags = element.get('tags', {})
                zone_type, buffer_m = _classify_osm_tags(tags)
                if zone_type is None:
                    continue

                element_type = element.get('type')
                if element_type == 'way' and 'geometry' in element:
                    coords = [(node['lon'], node['lat']) for node in element['geometry']]
                    polygon = _way_to_polygon(coords, buffer_m)
                elif element_type == 'relation' and tags.get('type') == 'multipolygon':
                    polygon = _relation_to_polygon(element)
                else:
                    continue

                if polygon is not None and polygon.is_valid and polygon.area > 0:
                    polygons.append(polygon)
                    zone_types.append(zone_type)
            except Exception as e:
                # Skip a single malformed element rather than failing the whole fetch.
                print(f"Error processing OSM element: {str(e)}")
                continue

        if not polygons:
            print("No valid polygons found in OSM data")
            return None

        gdf = gpd.GeoDataFrame(
            {'zone_type': zone_types},
            geometry=polygons,
            crs='EPSG:4326'
        )

        print(f"Successfully fetched {len(gdf)} exclusion zones from OpenStreetMap")
        print(f"Zone types found: {gdf['zone_type'].value_counts().to_dict()}")
        return gdf

    except ImportError:
        print("GeoPandas not available for OSM processing")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error fetching data from OpenStreetMap: {str(e)}")
        return None
    except Exception as e:
        print(f"Error processing OpenStreetMap data: {str(e)}")
        return None

def calculate_bounds_from_points(input_df: pd.DataFrame, buffer_km: float = 2.0) -> Tuple[float, float, float, float]:
    """
    Calculate a bounding box around the input points, padded by ``buffer_km``.

    Args:
        input_df: DataFrame with ``lat`` and ``lon`` columns in degrees.
        buffer_km: Padding added on every side, in kilometres.

    Returns:
        (min_lat, min_lon, max_lat, max_lon) in degrees. All values are NaN when
        ``input_df`` is empty.
    """
    min_lat = input_df['lat'].min()
    max_lat = input_df['lat'].max()
    min_lon = input_df['lon'].min()
    max_lon = input_df['lon'].max()

    km_per_degree = 111.0  # rough length of one degree of latitude
    lat_buffer = buffer_km / km_per_degree

    # One degree of longitude spans only cos(lat) * 111 km, so the longitude padding
    # must be widened to cover buffer_km. Use the point latitude farthest from the
    # equator (smallest cos) so the padding is sufficient across the whole box, and
    # clamp near the poles to avoid dividing by ~0.
    ref_lat = max(abs(min_lat), abs(max_lat))
    cos_lat = max(np.cos(np.radians(ref_lat)), 1e-3)
    lon_buffer = min(lat_buffer / cos_lat, 180.0)

    return (
        min_lat - lat_buffer,  # min_lat
        min_lon - lon_buffer,  # min_lon
        max_lat + lat_buffer,  # max_lat
        max_lon + lon_buffer   # max_lon
    )

class SpatialDiffuser:
    """
    Spatially diffuse weighted source points.

    Every input row (lat, lon, count, optional radius in metres) is expanded into
    ``count`` random points scattered around it according to one of the radial
    distributions in ``distribution_methods``. Points can optionally be kept out of
    exclusion-zone polygons and stamped with uniformly random timestamps.

    Metres are converted to degrees with a flat-earth approximation
    (1 degree of latitude ~= 111320 m, longitude scaled by cos(latitude)).
    """

    def __init__(self):
        # Public distribution name -> generator(lat, lon, count, radius) -> [(lat, lon), ...]
        self.distribution_methods = {
            "uniform": self._uniform_distribution,
            "normal": self._normal_distribution,
            "exponential_decay": self._exponential_decay,
            "distance_weighted": self._distance_weighted
        }

    def diffuse_points(self,
                       input_data: pd.DataFrame,
                       distribution_type: str = "uniform",
                       global_radius: Optional[float] = None,
                       time_start: Optional[datetime.datetime] = None,
                       time_end: Optional[datetime.datetime] = None,
                       seed: Optional[int] = None,
                       exclusion_zones_gdf: Optional[Any] = None) -> pd.DataFrame:
        """
        Generate diffused points based on input coordinates and counts.

        Args:
            input_data: DataFrame with columns lat, lon, count and optionally radius (metres).
            distribution_type: Key of ``distribution_methods`` selecting the radial distribution.
            global_radius: Radius in metres used when a row has no radius (column absent or NaN).
            time_start: Start time for temporal distribution.
            time_end: End time for temporal distribution.
            seed: Seeds both ``random`` and ``numpy.random`` (global state) for reproducible runs.
            exclusion_zones_gdf: GeoDataFrame with polygons that generated points must avoid.

        Returns:
            DataFrame with columns lat, lon, source_id (index label of the originating row)
            and, when both time_start and time_end are given, timestamp. The columns are
            present even when no points were generated.

        Raises:
            ValueError: Unsupported distribution_type, or a row with no usable radius.
        """
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        if distribution_type not in self.distribution_methods:
            raise ValueError(f"Distribution type '{distribution_type}' not supported. Choose from: {list(self.distribution_methods.keys())}")

        temporal = time_start is not None and time_end is not None
        columns = ['lat', 'lon', 'source_id'] + (['timestamp'] if temporal else [])
        all_points = []

        for idx, row in input_data.iterrows():
            radius = row.get('radius', global_radius)
            # An empty cell in an optional radius column arrives as NaN. Fall back to the
            # global radius instead of propagating NaN, which made the rejection-sampling
            # loops of the normal/exponential generators never terminate.
            if radius is None or pd.isna(radius):
                radius = global_radius
            if radius is None:
                raise ValueError("Radius must be specified either globally or per point")

            count = int(row['count'])
            if count <= 0:
                continue

            new_points = self._generate_points_with_exclusions(
                lat=row['lat'],
                lon=row['lon'],
                count=count,
                radius=radius,
                distribution_type=distribution_type,
                exclusion_zones_gdf=exclusion_zones_gdf
            )

            if temporal:
                timestamps = self._generate_timestamps(len(new_points), time_start, time_end)
                all_points.extend(
                    {'lat': p_lat, 'lon': p_lon, 'source_id': idx, 'timestamp': ts}
                    for (p_lat, p_lon), ts in zip(new_points, timestamps)
                )
            else:
                all_points.extend(
                    {'lat': p_lat, 'lon': p_lon, 'source_id': idx}
                    for p_lat, p_lon in new_points
                )

        # Explicit columns keep the schema stable even when all_points is empty.
        return pd.DataFrame(all_points, columns=columns)

    def _generate_points_with_exclusions(self, lat: float, lon: float, count: int, radius: float,
                                         distribution_type: str, exclusion_zones_gdf: Optional[Any] = None) -> List[Tuple[float, float]]:
        """
        Generate up to ``count`` points that fall outside every exclusion zone.

        Candidates are drawn in batches and filtered; at most ``10 * count`` candidates
        are tried, so fewer than ``count`` points may be returned (a warning is printed).
        If shapely is unavailable or zone processing fails, points are generated without
        exclusions.
        """
        generate = self.distribution_methods[distribution_type]

        if exclusion_zones_gdf is None or len(exclusion_zones_gdf) == 0:
            return generate(lat, lon, count, radius)

        try:
            from shapely.geometry import Point
            from shapely.prepared import prep

            # Points are WGS84 (lon, lat); bring the zones into the same CRS.
            if exclusion_zones_gdf.crs is None:
                exclusion_zones_gdf = exclusion_zones_gdf.set_crs('EPSG:4326')
            elif exclusion_zones_gdf.crs != 'EPSG:4326':
                exclusion_zones_gdf = exclusion_zones_gdf.to_crs('EPSG:4326')

            # Prepare each zone once: repeated point-in-polygon tests against prepared
            # geometries are much cheaper than re-iterating the GeoDataFrame per point.
            zones = [
                prep(geom) for geom in exclusion_zones_gdf.geometry
                if geom is not None and not geom.is_empty
            ]

            valid_points = []
            max_attempts = count * 10  # oversample to compensate for rejected candidates
            attempts = 0

            while len(valid_points) < count and attempts < max_attempts:
                batch_size = min(count * 2, max_attempts - attempts)
                for point in generate(lat, lon, batch_size, radius):
                    if len(valid_points) >= count:
                        break
                    candidate = Point(point[1], point[0])  # shapely uses (x=lon, y=lat)
                    if not any(zone.intersects(candidate) for zone in zones):
                        valid_points.append(point)
                attempts += batch_size

            if len(valid_points) < count:
                print(f"Warning: Could only generate {len(valid_points)} valid points out of {count} requested for location ({lat}, {lon}). Exclusion zones may be too restrictive.")

            return valid_points

        except ImportError:
            print("GeoPandas not available for exclusion zone processing. Generating points without exclusions.")
            return generate(lat, lon, count, radius)
        except Exception as e:
            print(f"Error processing exclusion zones: {str(e)}. Generating points without exclusions.")
            return generate(lat, lon, count, radius)

    @staticmethod
    def _offset_to_latlon(lat: float, lon: float, x: float, y: float) -> Tuple[float, float]:
        """Shift (lat, lon) by x metres east and y metres north (flat-earth approximation)."""
        lat_offset = y / 111320  # 1 degree latitude ~= 111320 metres
        lon_offset = x / (111320 * np.cos(np.radians(lat)))  # longitude degrees shrink with latitude
        return (lat + lat_offset, lon + lon_offset)

    def _uniform_distribution(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """Generate points uniformly distributed (by area) within a circle of ``radius`` metres."""
        points = []
        for _ in range(count):
            angle = random.uniform(0, 2 * np.pi)
            # sqrt of a uniform variate gives uniform density per unit area (no centre clustering).
            r = radius * np.sqrt(random.uniform(0, 1))
            points.append(self._offset_to_latlon(lat, lon, r * np.cos(angle), r * np.sin(angle)))
        return points

    def _normal_distribution(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """Generate points from a 2-D normal (sigma = radius / 3), truncated to the radius."""
        points = []
        std_dev = radius / 3  # ~99% of a 2-D normal lies within 3 sigma of the centre
        for _ in range(count):
            # Rejection sampling: redraw until the offset lies inside the radius.
            while True:
                x = np.random.normal(0, std_dev)
                y = np.random.normal(0, std_dev)
                if np.sqrt(x**2 + y**2) <= radius:
                    break
            points.append(self._offset_to_latlon(lat, lon, x, y))
        return points

    def _exponential_decay(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """Generate points whose distance from centre is exponential (rate 3 / radius), truncated to the radius."""
        points = []
        rate = 3.0 / radius  # higher rate = steeper decay
        for _ in range(count):
            angle = random.uniform(0, 2 * np.pi)
            # Rejection sampling keeps the exponential distance within the radius.
            while True:
                r = random.expovariate(rate)
                if r <= radius:
                    break
            points.append(self._offset_to_latlon(lat, lon, r * np.cos(angle), r * np.sin(angle)))
        return points

    def _distance_weighted(self, lat: float, lon: float, count: int, radius: float) -> List[Tuple[float, float]]:
        """
        Generate points with a distance-weighted distribution: (r / radius)^2 ~ Beta(2, 2),
        which puts more points at medium distances than at the centre or the edge.
        """
        points = []
        for _ in range(count):
            angle = random.uniform(0, 2 * np.pi)
            r = np.sqrt(random.betavariate(2, 2)) * radius
            points.append(self._offset_to_latlon(lat, lon, r * np.cos(angle), r * np.sin(angle)))
        return points

    def _generate_timestamps(self, count: int, start_time: datetime.datetime, end_time: datetime.datetime) -> List[datetime.datetime]:
        """
        Return ``count`` timestamps drawn uniformly between start_time and end_time, sorted.

        Results carry start_time's tzinfo: aware inputs give aware outputs in the same
        zone; naive inputs give naive local-time values, as ``datetime.fromtimestamp`` does.
        """
        start_ts = start_time.timestamp()
        end_ts = end_time.timestamp()
        tz = start_time.tzinfo
        return sorted(
            datetime.datetime.fromtimestamp(random.uniform(start_ts, end_ts), tz=tz)
            for _ in range(count)
        )

def _clear_axes(ax):
    """Remove everything drawn on ``ax`` and restore the plot background colour."""
    ax.cla()
    ax.set_facecolor('#f8f9fa')


def _zone_type_labels(zones_gdf):
    """Return each zone's 'zone_type', substituting 'Other' where it is absent.

    User-supplied exclusion files need not carry a 'zone_type' column, and
    individual values may be missing.
    """
    if 'zone_type' not in zones_gdf.columns:
        return pd.Series('Other', index=zones_gdf.index)
    return zones_gdf['zone_type'].fillna('Other')


def _radius_values(input_df):
    """Return the per-row 'radius' values (meters), with None where unavailable.

    A row gets None when the CSV has no 'radius' column or the cell is empty/NaN.
    """
    if 'radius' not in input_df.columns:
        return [None] * len(input_df)
    return [value if pd.notna(value) else None for value in input_df['radius']]


def _marker_sizes_by_count(counts, min_size=100, max_size=800):
    """Map source-point counts linearly onto scatter marker areas.

    With several sources, the smallest count gets ``min_size`` and the largest
    gets ``max_size``. If every count is equal, all markers get ``min_size``.
    With zero or one source, a single ``max_size`` is returned.
    """
    if len(counts) <= 1:
        return [max_size]
    low = counts.min()
    size_range = counts.max() - low
    if size_range > 0:
        return min_size + (counts - low) / size_range * (max_size - min_size)
    return [min_size] * len(counts)


def _annotate_counts(ax, xs, ys, counts):
    """Label each source point at (x, y) with its integer count in a purple badge."""
    for x, y, count in zip(xs, ys, counts):
        ax.annotate(f'{int(count)}',
                    (x, y),
                    xytext=(8, 8), textcoords='offset points',
                    fontsize=10, fontweight='bold', color='white',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='#9C27B0', alpha=0.8),
                    zorder=6)


def _apply_padded_limits(ax, x_coords, y_coords, zero_span_margin):
    """Fit the axis limits to the coordinates plus 10% padding on each side.

    ``zero_span_margin`` is used on an axis whose coordinates are all identical.
    The limits are left untouched when there are no coordinates at all.
    """
    if not x_coords or not y_coords:
        return
    x_margin = (max(x_coords) - min(x_coords)) * 0.1
    y_margin = (max(y_coords) - min(y_coords)) * 0.1
    if x_margin == 0:
        x_margin = zero_span_margin
    if y_margin == 0:
        y_margin = zero_span_margin
    ax.set_xlim(min(x_coords) - x_margin, max(x_coords) + x_margin)
    ax.set_ylim(min(y_coords) - y_margin, max(y_coords) + y_margin)


def create_visualization(input_df, output_df, show_basemap=False, exclusion_zones_gdf=None):
    """
    Plot source points, generated (diffused) points, diffusion radii and exclusion zones.

    Args:
        input_df: Source points with 'lat', 'lon' and 'count' columns and an
            optional per-row 'radius' column (meters). Rows without a usable
            radius get no outline.
        output_df: Generated points with 'lat' and 'lon' columns.
        show_basemap: When True, project everything to Web Mercator
            (EPSG:3857) and try to draw CartoDB Positron tiles underneath via
            contextily. If the projected plot cannot be built, the axes are
            cleared and a plain lon/lat plot is drawn instead.
        exclusion_zones_gdf: Optional GeoDataFrame of exclusion areas. An
            optional 'zone_type' column selects the fill colour; missing
            types are drawn as 'Other'.

    Returns:
        The matplotlib Figure.
    """
    from matplotlib.patches import Ellipse

    fig, ax = plt.subplots(figsize=(12, 10))

    # Set background color
    fig.patch.set_facecolor('white')
    ax.set_facecolor('#f8f9fa')

    # Fill colours per exclusion-zone type; unknown types use 'Other'
    exclusion_colors = {
        'Water': '#4FC3F7',           # Light blue
        'Park': '#66BB6A',            # Green
        'Green space': '#81C784',     # Light green
        'Forest': '#4CAF50',          # Dark green
        'Industrial/Commercial': '#90A4AE',  # Grey
        'Major road': '#FFD54F',      # Yellow
        'Other': '#FFAB91'            # Light orange
    }

    has_exclusion_zones = exclusion_zones_gdf is not None and len(exclusion_zones_gdf) > 0
    radii = _radius_values(input_df)
    sizes = _marker_sizes_by_count(input_df['count'])
    basemap_added = False

    if show_basemap:
        try:
            import contextily as ctx
            import geopandas as gpd
            from shapely.geometry import Point

            # Project source and generated points to Web Mercator, the CRS the
            # tile basemap is drawn in
            input_merc = gpd.GeoDataFrame(
                input_df,
                geometry=[Point(lon, lat) for lon, lat in zip(input_df['lon'], input_df['lat'])],
                crs='EPSG:4326',
            ).to_crs('EPSG:3857')
            output_merc = gpd.GeoDataFrame(
                output_df,
                geometry=[Point(lon, lat) for lon, lat in zip(output_df['lon'], output_df['lat'])],
                crs='EPSG:4326',
            ).to_crs('EPSG:3857')
            input_x, input_y = input_merc.geometry.x, input_merc.geometry.y
            output_x, output_y = output_merc.geometry.x, output_merc.geometry.y

            # Exclusion zones go down first, one colour per zone type
            if has_exclusion_zones:
                try:
                    zones_merc = exclusion_zones_gdf.to_crs('EPSG:3857')
                    zone_types = _zone_type_labels(zones_merc)
                    for zone_type in zone_types.unique():
                        color = exclusion_colors.get(zone_type, exclusion_colors['Other'])
                        zones_merc[(zone_types == zone_type).to_numpy()].plot(
                            ax=ax, color=color, alpha=0.6, edgecolor='white',
                            linewidth=0.5, label=zone_type)
                except Exception as e:
                    print(f"Error plotting exclusion zones: {str(e)}")

            # Generated points (orange) before source points so sources stay on top
            ax.scatter(output_x, output_y,
                       alpha=0.7, color='#FF9800', s=12, label=f'Generated Points (n={len(output_df)})',
                       edgecolors='white', linewidth=0.3)

            # Dashed diffusion-radius outlines. Web Mercator inflates ground
            # distances by 1/cos(latitude), so r meters on the ground span
            # r / cos(lat) projected units at the source point.
            for cx, cy, lat, radius in zip(input_x, input_y, input_df['lat'], radii):
                if radius is None:
                    continue
                ax.add_patch(plt.Circle((cx, cy), radius / np.cos(np.radians(lat)),
                                        fill=False, color='#9C27B0', linestyle='--',
                                        alpha=0.5, linewidth=2))

            # Source points in purple, area scaled by count
            ax.scatter(input_x, input_y,
                       s=sizes, c='#9C27B0', alpha=0.9,
                       edgecolors='white', linewidth=2,
                       label='Source Points (size = count)', zorder=5)
            _annotate_counts(ax, input_x, input_y, input_df['count'])

            ax.set_xlabel('Easting (Web Mercator)', fontsize=12)
            ax.set_ylabel('Northing (Web Mercator)', fontsize=12)

            # Fix the padded view before fetching tiles so the tiles cover all of it
            _apply_padded_limits(ax, list(input_x) + list(output_x),
                                 list(input_y) + list(output_y), 1000)

            try:
                ctx.add_basemap(ax, crs='EPSG:3857', source=ctx.providers.CartoDB.Positron, alpha=0.8)
                basemap_added = True
            except Exception as e:
                print(f"Could not add basemap: {str(e)}")

        except ImportError:
            print("Contextily not available for basemap. Falling back to simple plot.")
            show_basemap = False
        except Exception as e:
            print(f"Error creating basemap: {str(e)}. Falling back to simple plot.")
            show_basemap = False

        if not show_basemap:
            # Discard any partially drawn Web Mercator artists so the lon/lat
            # plot below is not mixed with projected coordinates
            _clear_axes(ax)

    # Plain longitude/latitude plot, used when no basemap is requested or it failed
    if not show_basemap:
        if has_exclusion_zones:
            try:
                zones = exclusion_zones_gdf
                if zones.crs != 'EPSG:4326':
                    zones = zones.to_crs('EPSG:4326')

                plotted_types = set()
                for zone_type, geom in zip(_zone_type_labels(zones), zones.geometry):
                    if geom is None or geom.is_empty:
                        continue
                    if geom.geom_type == 'Polygon':
                        polygons = [geom]
                    elif geom.geom_type == 'MultiPolygon':
                        polygons = list(geom.geoms)
                    else:
                        continue  # Lines/points have no area to fill

                    color = exclusion_colors.get(zone_type, exclusion_colors['Other'])
                    # Label only the first drawn polygon of each type so the legend has one entry per type
                    label = None if zone_type in plotted_types else zone_type
                    plotted_types.add(zone_type)
                    for poly in polygons:
                        x, y = poly.exterior.xy
                        ax.fill(x, y, facecolor=color, alpha=0.6, edgecolor='white',
                                linewidth=0.5, label=label)
                        label = None
            except Exception as e:
                print(f"Error plotting exclusion zones: {str(e)}")

        # Generated points (orange) before source points so sources stay on top
        ax.scatter(output_df['lon'], output_df['lat'],
                   alpha=0.7, color='#FF9800', s=12, label=f'Generated Points (n={len(output_df)})',
                   edgecolors='white', linewidth=0.3)

        # Dashed diffusion-radius outlines. The meters -> degrees conversion
        # uses different scales for latitude and longitude, so the region is
        # an ellipse in degree space.
        for lon, lat, radius in zip(input_df['lon'], input_df['lat'], radii):
            if radius is None:
                continue
            radius_deg_lat = radius / 111320
            radius_deg_lon = radius / (111320 * np.cos(np.radians(lat)))
            ax.add_patch(Ellipse((lon, lat), width=2 * radius_deg_lon, height=2 * radius_deg_lat,
                                 fill=False, color='#9C27B0', linestyle='--',
                                 alpha=0.5, linewidth=2))

        # Source points in purple, area scaled by count
        ax.scatter(input_df['lon'], input_df['lat'],
                   s=sizes, c='#9C27B0', alpha=0.9,
                   edgecolors='white', linewidth=2,
                   label='Source Points (size = count)', zorder=5)
        _annotate_counts(ax, input_df['lon'], input_df['lat'], input_df['count'])

        ax.set_xlabel('Longitude', fontsize=12)
        ax.set_ylabel('Latitude', fontsize=12)

        _apply_padded_limits(ax, list(input_df['lon']) + list(output_df['lon']),
                             list(input_df['lat']) + list(output_df['lat']), 0.01)

    # Common styling
    title = 'Spatial Diffusion Results'
    if basemap_added:
        title += ' (with Basemap)'
    if has_exclusion_zones:
        title += ' - Exclusion Zones Applied'
    subtitle = 'Purple source points sized by count, orange generated points, dashed circles show diffusion radius'

    ax.set_title(f'{title}\n{subtitle}',
                 fontsize=14, fontweight='bold', pad=20)

    legend = ax.legend(loc='upper right', bbox_to_anchor=(1, 1),
                       frameon=True, fancybox=True, shadow=True)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)

    # Lighter grid over the projected (basemap) plot
    grid_alpha = 0.2 if show_basemap else 0.3
    ax.grid(True, alpha=grid_alpha, linestyle='-', linewidth=0.5)

    # Equal aspect; 'box' resizes the axes box and keeps the limits set above
    ax.set_aspect('equal', 'box')

    plt.tight_layout()

    return fig

def _uploaded_path(uploaded):
    """Return the filesystem path of a Gradio upload, or None when nothing was uploaded.

    Gradio 4+ hands callbacks a path string by default; older versions pass a
    temp-file wrapper whose ``.name`` attribute holds the path.
    """
    if uploaded is None:
        return None
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    return uploaded.name


def process_csv(file_obj, distribution_type, global_radius, show_basemap, auto_exclusions, exclusion_file, include_time, time_start, time_end, seed):
    """
    Gradio callback: read the uploaded CSV, diffuse its points and plot the result.

    Args:
        file_obj: Uploaded CSV, either a path string or an object with ``.name``.
        distribution_type: Distribution name forwarded to ``SpatialDiffuser.diffuse_points``.
        global_radius: Radius in meters as text. Blank means use the CSV's 'radius' column.
        show_basemap: Whether to draw the plot over web map tiles.
        auto_exclusions: OpenStreetMap categories to exclude. Ignored when
            ``exclusion_file`` is given.
        exclusion_file: Optional .geojson/.json/.gpkg/.shp file of exclusion zones.
        include_time: Whether to assign timestamps between ``time_start`` and ``time_end``.
        time_start, time_end: Bounds formatted as 'YYYY-MM-DD HH:MM:SS'.
        seed: Optional integer random seed, as text.

    Returns:
        ``(figure, csv_path)`` on success, or ``(None, error_message)`` on failure.
    """
    import tempfile

    try:
        input_path = _uploaded_path(file_obj)
        if input_path is None:
            return None, "Error: Please upload an input CSV file"
        df = pd.read_csv(input_path)

        # Validate required columns
        required_cols = ['lat', 'lon', 'count']
        if not all(col in df.columns for col in required_cols):
            return None, f"Error: CSV must contain columns: {', '.join(required_cols)}"
        if df.empty:
            return None, "Error: CSV contains no data rows"

        # A blank global radius means "use the per-row 'radius' column"
        if global_radius is not None and str(global_radius).strip():
            try:
                global_radius = float(global_radius)
            except ValueError:
                return None, "Error: Global radius must be a number"
            # float() also accepts 'nan', 'inf' and negatives, none of which is a usable radius
            if not np.isfinite(global_radius) or global_radius <= 0:
                return None, "Error: Global radius must be a positive number"
        else:
            global_radius = None
            if 'radius' not in df.columns:
                return None, "Error: Either provide a global radius or include a 'radius' column in the CSV"

        # Optional integer seed
        if seed is not None and str(seed).strip():
            try:
                seed = int(seed)
            except ValueError:
                return None, "Error: Seed must be an integer"
        else:
            seed = None

        # Exclusion zones: an uploaded file takes priority over automatic OSM exclusions
        exclusion_zones_gdf = None
        if exclusion_file is not None:
            exclusion_path = _uploaded_path(exclusion_file)
            try:
                import geopandas as gpd

                file_extension = os.path.splitext(exclusion_path)[1].lower()
                if file_extension not in ('.geojson', '.json', '.gpkg', '.shp'):
                    return None, f"Error: Unsupported exclusion zone file format: {file_extension}"

                exclusion_zones_gdf = gpd.read_file(exclusion_path)

                # Assume WGS84 when the file carries no CRS
                if exclusion_zones_gdf.crs is None:
                    exclusion_zones_gdf = exclusion_zones_gdf.set_crs('EPSG:4326')

                print(f"Loaded {len(exclusion_zones_gdf)} custom exclusion zones from {exclusion_path}")

            except ImportError:
                return None, "Error: GeoPandas required for exclusion zones processing"
            except Exception as e:
                return None, f"Error reading exclusion zones file: {str(e)}"

        elif auto_exclusions:
            try:
                bounds = calculate_bounds_from_points(df)
                print(f"Fetching automatic exclusions for bounds: {bounds}")

                exclusion_zones_gdf = fetch_osm_exclusion_zones(bounds, auto_exclusions)

                if exclusion_zones_gdf is not None:
                    print(f"Fetched {len(exclusion_zones_gdf)} exclusion zones from OpenStreetMap")
                else:
                    print("No exclusion zones found in OpenStreetMap for this area")

            except Exception as e:
                # Best effort: continue without exclusions rather than failing completely
                print(f"Warning: Could not fetch automatic exclusions: {str(e)}")
                exclusion_zones_gdf = None

        # Optional temporal window
        time_start_dt = None
        time_end_dt = None
        if include_time:
            time_start = (time_start or '').strip()
            time_end = (time_end or '').strip()
            if not time_start or not time_end:
                return None, "Error: If time distribution is enabled, both start and end times must be provided"
            try:
                time_start_dt = datetime.datetime.strptime(time_start, "%Y-%m-%d %H:%M:%S")
                time_end_dt = datetime.datetime.strptime(time_end, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                return None, "Error: Invalid time format. Use YYYY-MM-DD HH:MM:SS"
            if time_start_dt >= time_end_dt:
                return None, "Error: End time must be after start time"

        diffuser = SpatialDiffuser()
        result_df = diffuser.diffuse_points(
            input_data=df,
            distribution_type=distribution_type,
            global_radius=global_radius,
            time_start=time_start_dt,
            time_end=time_end_dt,
            seed=seed,
            exclusion_zones_gdf=exclusion_zones_gdf
        )

        # Write into a fresh per-request directory. This does not depend on the
        # working directory being writable, and requests cannot overwrite each
        # other's download. The file name users see stays the same.
        output_dir = tempfile.mkdtemp(prefix="spatial_diffusion_")
        temp_file = os.path.join(output_dir, "diffused_points.csv")
        result_df.to_csv(temp_file, index=False)

        fig = create_visualization(df, result_df, show_basemap, exclusion_zones_gdf)

        return fig, temp_file

    except Exception as e:
        return None, f"Error: {str(e)}"

def create_diffusion_interface():
    """
    Build the Gradio Blocks UI for the spatial diffusion tool.

    The left column holds the CSV upload and the distribution, temporal and
    exclusion options. The right column shows the plot, the download link and
    a colour legend. The generate button calls ``process_csv`` with the inputs
    in its positional order.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """

    with gr.Blocks() as diffusion_interface:
        gr.Markdown("## 🗺️ Spatial Diffusion Tool")

        with gr.Row():
            with gr.Column(scale=1):
                # Description lives in the left column for better space usage
                gr.Markdown("""

                ### About This Tool

                Transform aggregated geographic points with counts into individual points using spatial diffusion methods.

                

                **Input CSV Format:**

                - `lat`: Latitude of source point

                - `lon`: Longitude of source point  

                - `count`: Number of points to generate

                - `radius`: (Optional) Diffusion radius in meters

                

                **Distribution Types:**

                - **Uniform**: Equal probability throughout circle

                - **Normal**: Higher density near center

                - **Exponential Decay**: Density decreases from center

                - **Distance-Weighted**: More points at medium distances

                """)

                # Input controls
                input_file = gr.File(label="Input CSV File", file_types=[".csv"])

                # Distribution options grouped together
                gr.Markdown("### 🎯 Distribution Options")
                with gr.Row():
                    distribution = gr.Dropdown(
                        choices=["uniform", "normal", "exponential_decay", "distance_weighted"],
                        value="uniform",
                        label="Distribution Type",
                        scale=2
                    )
                    seed = gr.Textbox(
                        label="Random Seed (optional)", 
                        placeholder="e.g. 42",
                        scale=1
                    )

                global_radius = gr.Textbox(
                    label="Global Radius (meters)", 
                    placeholder="Only if radius column not in CSV"
                )

                # Temporal controls in distribution section
                with gr.Accordion("⏰ Temporal Distribution (Optional)", open=False):
                    include_time = gr.Checkbox(label="Enable Temporal Distribution", value=False)
                    # Hidden to match the checkbox's initial False value; toggled by include_time.change below
                    with gr.Group(visible=False) as time_group:
                        time_start = gr.Textbox(
                            label="Start Time", 
                            placeholder="YYYY-MM-DD HH:MM:SS"
                        )
                        time_end = gr.Textbox(
                            label="End Time", 
                            placeholder="YYYY-MM-DD HH:MM:SS"
                        )

                # Map and exclusion options grouped together
                gr.Markdown("### 🗺️ Map & Exclusion Options")
                show_basemap = gr.Checkbox(
                    label="Show underlying map (requires internet)", 
                    value=False
                )
                gr.Markdown("*Adds geographic context with street/satellite imagery*")

                # Automatic exclusion zones - no default selection
                auto_exclusions = gr.CheckboxGroup(
                    label="Auto-exclude from OpenStreetMap:",
                    choices=["Water bodies", "Parks & green spaces", "Industrial areas", "Major roads"],
                    value=[]  # No default selections
                )

                # Advanced manual exclusion zones; process_csv prefers these over auto exclusions
                with gr.Accordion("🔧 Advanced: Custom Exclusion Zones", open=False):
                    exclusion_file = gr.File(
                        label="Upload custom shapefile (optional)",
                        file_types=[".geojson", ".json", ".gpkg", ".shp"]
                    )
                    gr.Markdown("*Overrides automatic exclusions if provided*")

                process_btn = gr.Button(
                    "🎯 Generate Diffused Points", 
                    variant="primary", 
                    size="lg"
                )

            with gr.Column(scale=2):
                # Give more space to visualization
                plot_output = gr.Plot(
                    label="📍 Spatial Diffusion Visualization",
                    show_label=True
                )

                with gr.Row():
                    with gr.Column(scale=2):
                        file_output = gr.File(label="📥 Download Generated Points")
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """

                            **Legend:**  

                            🟣 Source points (sized by count)  

                            🟠 Generated points  

                            ⭕ Diffusion radius  

                            🟦 Water bodies  

                            🟢 Parks & green spaces  

                            ⬜ Industrial areas  

                            🟡 Major roads

                            """
                        )

        # Input order must match process_csv's positional parameters
        process_btn.click(
            fn=process_csv,
            inputs=[input_file, distribution, global_radius, show_basemap, auto_exclusions, exclusion_file, include_time, time_start, time_end, seed],
            outputs=[plot_output, file_output]
        )

        # Show/hide time inputs based on checkbox
        include_time.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[include_time],
            outputs=[time_group]
        )

    return diffusion_interface

if __name__ == "__main__":
    # Running the module as a script launches the Gradio app on its own,
    # which is handy for manual testing outside any host application.
    create_diffusion_interface().launch()