arnavmishra4 commited on
Commit
c5122a2
Β·
verified Β·
1 Parent(s): ca31e65

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +764 -700
main.py CHANGED
@@ -1,701 +1,765 @@
1
- """
2
- Archaeological Site Detection - Gradio Frontend
3
- ================================================
4
- Interactive web interface for SONAR 2.0 archaeological detection system
5
- with GATE model predictions and multi-model analysis.
6
-
7
- Run: python gradio_app.py
8
- """
9
-
10
- import gradio as gr
11
- import numpy as np
12
- import pandas as pd
13
- import matplotlib.pyplot as plt
14
- import matplotlib
15
- matplotlib.use('Agg')
16
- from matplotlib.patches import Rectangle
17
- import plotly.graph_objects as go
18
- from pathlib import Path
19
- import torch
20
- import json
21
- from typing import Dict, List, Tuple, Optional
22
- import io
23
- from PIL import Image
24
- import warnings
25
- warnings.filterwarnings('ignore')
26
-
27
- # Import your existing utilities
28
- from utils import (
29
- ResUNetAutoencoder, ResUNetEncoder, load_patches,
30
- load_model, load_kmeans_model, compute_autoencoder_probabilities,
31
- compute_iforest_probabilities, compute_kmeans_probabilities,
32
- load_unified_probability_matrix
33
- )
34
- from arch_similarity_utils import compute_arch_similarity_channel
35
- from visualization import (
36
- generate_2d_patch_preview, generate_3d_terrain_json,
37
- generate_gate_prediction_heatmap, generate_gate_positive_patches_visualization,
38
- generate_probability_matrix_visualization, generate_full_aoi_heatmap,
39
- generate_gate_statistics_json
40
- )
41
-
42
- # ==============================================================================
43
- # CONFIGURATION
44
- # ==============================================================================
45
-
46
- class Config:
47
- """Application configuration"""
48
- # Model paths
49
- AUTOENCODER_PATH = Path('models/best_model_aoi.pth')
50
- ENCODER_DIM = 128
51
- IFOREST_PATH = Path('models/isolation_forest_model_128dim.pkl')
52
- KMEANS_PATH = Path('models/kmeans_model_128dim.pkl')
53
- ARCH_EMBEDDINGS_CSV = Path('GATE/Arch_embedding_only_128dim.csv')
54
- GATE_MODEL_PKL = Path('GATE/gate_mlp_model.pkl')
55
- GATE_SCALER_PATH = Path('GATE/gate_scaler.pkl')
56
-
57
- # Data paths
58
- PATCHES_DIR = Path('patches/patches_final')
59
- UNIFIED_PROB_DIR = Path('src/unified_probability_matrices_with_gate')
60
-
61
- # Processing
62
- BATCH_SIZE = 32
63
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
-
65
- # UI
66
- THEME = gr.themes.Soft(
67
- primary_hue="blue",
68
- secondary_hue="slate",
69
- )
70
-
71
- config = Config()
72
-
73
- # ==============================================================================
74
- # MODEL MANAGER
75
- # ==============================================================================
76
-
77
- class ModelManager:
78
- """Manages all trained models"""
79
-
80
- def __init__(self):
81
- self.models_loaded = False
82
- self.autoencoder = None
83
- self.encoder = None
84
- self.iforest = None
85
- self.scaler_iforest = None
86
- self.kmeans = None
87
- self.scaler_kmeans = None
88
- self.gate_model = None
89
- self.gate_scaler = None
90
-
91
- def load_all_models(self):
92
- """Load all models"""
93
- status_messages = []
94
-
95
- try:
96
- # Autoencoder
97
- self.autoencoder = ResUNetAutoencoder(in_channels=7).to(config.DEVICE)
98
- self.autoencoder.load_state_dict(
99
- torch.load(config.AUTOENCODER_PATH, map_location=config.DEVICE)
100
- )
101
- self.autoencoder.eval()
102
- status_messages.append("βœ… Autoencoder loaded")
103
-
104
- # Encoder
105
- self.encoder = ResUNetEncoder(
106
- in_channels=7,
107
- embedding_dim=config.ENCODER_DIM
108
- ).to(config.DEVICE)
109
- self.encoder.load_from_autoencoder(str(config.AUTOENCODER_PATH))
110
- self.encoder.eval()
111
- status_messages.append("βœ… Encoder loaded")
112
-
113
- # Isolation Forest
114
- self.iforest, self.scaler_iforest = load_model(str(config.IFOREST_PATH))
115
- status_messages.append("βœ… Isolation Forest loaded")
116
-
117
- # K-Means
118
- self.kmeans, self.scaler_kmeans = load_kmeans_model(str(config.KMEANS_PATH))
119
- status_messages.append("βœ… K-Means loaded")
120
-
121
- # GATE Model
122
- import joblib
123
- self.gate_model = joblib.load(config.GATE_MODEL_PKL)
124
- self.gate_scaler = joblib.load(config.GATE_SCALER_PATH)
125
- status_messages.append("βœ… GATE model loaded")
126
-
127
- self.models_loaded = True
128
- status_messages.append("\nπŸŽ‰ All models loaded successfully!")
129
-
130
- except Exception as e:
131
- status_messages.append(f"❌ Error loading models: {e}")
132
- import traceback
133
- status_messages.append(traceback.format_exc())
134
-
135
- return "\n".join(status_messages)
136
-
137
- def get_models_dict(self):
138
- """Return models as dictionary"""
139
- return {
140
- 'autoencoder': self.autoencoder,
141
- 'encoder': self.encoder,
142
- 'iforest': self.iforest,
143
- 'scaler_iforest': self.scaler_iforest,
144
- 'kmeans': self.kmeans,
145
- 'scaler_kmeans': self.scaler_kmeans,
146
- 'gate_model': self.gate_model,
147
- 'gate_scaler': self.gate_scaler
148
- }
149
-
150
- # Global model manager
151
- model_manager = ModelManager()
152
-
153
- # ==============================================================================
154
- # DATA MANAGER
155
- # ==============================================================================
156
-
157
- class DataManager:
158
- """Manages AOI data and patches"""
159
-
160
- def __init__(self):
161
- self.aoi_list = self._discover_aois()
162
- self.current_aoi = None
163
- self.current_patches = None
164
- self.current_metadata = None
165
- self.current_unified_matrix = None
166
-
167
- def _discover_aois(self):
168
- """Discover all available AOIs"""
169
- if not config.PATCHES_DIR.exists():
170
- return []
171
-
172
- patch_files = list(config.PATCHES_DIR.glob("AOI_*_all_patches.npz"))
173
- aoi_names = sorted([f.stem.replace('_all_patches', '') for f in patch_files])
174
- return aoi_names
175
-
176
- def load_aoi(self, aoi_name: str):
177
- """Load patches and unified matrix for an AOI"""
178
- try:
179
- # Load patches
180
- patches_file = config.PATCHES_DIR / f"{aoi_name}_all_patches.npz"
181
- self.current_patches, self.current_metadata = load_patches(patches_file)
182
-
183
- # Load unified probability matrix
184
- matrix_file = config.UNIFIED_PROB_DIR / f"{aoi_name}_unified_prob_matrix.npz"
185
- if matrix_file.exists():
186
- self.current_unified_matrix, _, _ = load_unified_probability_matrix(
187
- aoi_name, config.UNIFIED_PROB_DIR
188
- )
189
- else:
190
- self.current_unified_matrix = None
191
-
192
- self.current_aoi = aoi_name
193
-
194
- return f"βœ… Loaded {aoi_name}: {len(self.current_patches)} patches"
195
-
196
- except Exception as e:
197
- return f"❌ Error loading {aoi_name}: {e}"
198
-
199
- def get_aoi_info(self):
200
- """Get current AOI information"""
201
- if self.current_aoi is None:
202
- return "No AOI loaded"
203
-
204
- info = [
205
- f"**AOI:** {self.current_aoi}",
206
- f"**Total Patches:** {len(self.current_patches)}",
207
- f"**Patch Shape:** {self.current_patches[0].shape}",
208
- ]
209
-
210
- if self.current_unified_matrix is not None:
211
- info.append(f"**Unified Matrix:** {self.current_unified_matrix.shape}")
212
- info.append(f"**Channels:** 5 (incl. GATE)")
213
-
214
- return "\n".join(info)
215
-
216
- def get_patch(self, patch_idx: int):
217
- """Get specific patch"""
218
- if self.current_patches is None or patch_idx >= len(self.current_patches):
219
- return None, None
220
-
221
- return self.current_patches[patch_idx], self.current_metadata[patch_idx]
222
-
223
- # Global data manager
224
- data_manager = DataManager()
225
-
226
- # ==============================================================================
227
- # VISUALIZATION FUNCTIONS
228
- # ==============================================================================
229
-
230
- def create_patch_visualization(patch: np.ndarray, metadata: Dict):
231
- """Create 2D visualization of patch"""
232
- channel_names = ['DTM', 'Slope', 'Roughness', 'NDVI', 'NDWI', 'FlowAcc', 'FlowDir']
233
-
234
- fig, axes = plt.subplots(2, 4, figsize=(16, 8))
235
- axes = axes.flatten()
236
-
237
- for i in range(7):
238
- ax = axes[i]
239
- data = patch[i]
240
-
241
- # Choose colormap
242
- if i == 0: # DTM
243
- cmap = 'terrain'
244
- elif i == 1: # Slope
245
- cmap = 'YlOrRd'
246
- elif i == 3: # NDVI
247
- cmap = 'RdYlGn'
248
- elif i == 4: # NDWI
249
- cmap = 'Blues'
250
- else:
251
- cmap = 'viridis'
252
-
253
- im = ax.imshow(data, cmap=cmap, interpolation='bilinear')
254
- ax.set_title(f'{channel_names[i]}\nΞΌ={np.nanmean(data):.2f}', fontweight='bold')
255
- ax.axis('off')
256
- plt.colorbar(im, ax=ax, fraction=0.046)
257
-
258
- # Hide last subplot
259
- axes[7].axis('off')
260
-
261
- plt.suptitle(f"Patch {metadata['patch_id']} (row={metadata['row']}, col={metadata['col']})",
262
- fontsize=14, fontweight='bold')
263
- plt.tight_layout()
264
-
265
- # Convert to image
266
- buf = io.BytesIO()
267
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
268
- plt.close(fig)
269
- buf.seek(0)
270
-
271
- return Image.open(buf)
272
-
273
- def create_3d_terrain(patch: np.ndarray, metadata: Dict):
274
- """Create 3D terrain visualization"""
275
- dtm = patch[0] # First channel is DTM
276
-
277
- # Create coordinate grids
278
- rows, cols = dtm.shape
279
- x = np.arange(cols)
280
- y = np.arange(rows)
281
- X, Y = np.meshgrid(x, y)
282
-
283
- # Handle NaN
284
- dtm_clean = np.nan_to_num(dtm, nan=np.nanmedian(dtm))
285
-
286
- # Create surface
287
- fig = go.Figure(data=[go.Surface(
288
- z=dtm_clean,
289
- x=X,
290
- y=Y,
291
- colorscale='earth',
292
- colorbar=dict(title='Elevation (m)'),
293
- lighting=dict(
294
- ambient=0.4,
295
- diffuse=0.8,
296
- fresnel=0.2,
297
- specular=0.3,
298
- roughness=0.5
299
- )
300
- )])
301
-
302
- fig.update_layout(
303
- title=f"3D Terrain - {metadata['patch_id']}",
304
- scene=dict(
305
- xaxis_title='X (pixels)',
306
- yaxis_title='Y (pixels)',
307
- zaxis_title='Elevation (m)',
308
- camera=dict(eye=dict(x=1.5, y=1.5, z=1.3)),
309
- aspectmode='manual',
310
- aspectratio=dict(x=1, y=1, z=0.5)
311
- ),
312
- width=800,
313
- height=600
314
- )
315
-
316
- return fig
317
-
318
- def create_probability_channels_viz(unified_matrix: np.ndarray, patch_idx: int):
319
- """Visualize all 5 probability channels"""
320
- patch_data = unified_matrix[patch_idx]
321
- channel_names = ['Autoencoder', 'IForest', 'K-Means', 'Arch Similarity', 'GATE Prediction']
322
-
323
- fig, axes = plt.subplots(2, 3, figsize=(15, 10))
324
- axes = axes.flatten()
325
-
326
- for i in range(5):
327
- ax = axes[i]
328
- data = patch_data[:, :, i]
329
-
330
- im = ax.imshow(data, cmap='hot', vmin=0, vmax=1, interpolation='bilinear')
331
- ax.set_title(f'{channel_names[i]}\nΞΌ={data.mean():.3f}, max={data.max():.3f}',
332
- fontweight='bold')
333
- ax.axis('off')
334
- plt.colorbar(im, ax=ax, fraction=0.046)
335
-
336
- # Hide last subplot
337
- axes[5].axis('off')
338
-
339
- plt.suptitle(f'Probability Channels - Patch {patch_idx}',
340
- fontsize=14, fontweight='bold')
341
- plt.tight_layout()
342
-
343
- # Convert to image
344
- buf = io.BytesIO()
345
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
346
- plt.close(fig)
347
- buf.seek(0)
348
-
349
- return Image.open(buf)
350
-
351
- def create_gate_heatmap(unified_matrix: np.ndarray, metadata: List[Dict], aoi_name: str):
352
- """Create GATE prediction heatmap"""
353
- # Determine AOI shape
354
- max_row = max(m['row'] + 64 for m in metadata)
355
- max_col = max(m['col'] + 64 for m in metadata)
356
- aoi_shape = (max_row, max_col)
357
-
358
- # Generate heatmap
359
- heatmap_bytes = generate_gate_prediction_heatmap(
360
- unified_matrix, metadata, aoi_shape, aoi_name, threshold=0.5
361
- )
362
-
363
- return Image.open(io.BytesIO(heatmap_bytes))
364
-
365
- def create_top_candidates(unified_matrix: np.ndarray, metadata: List[Dict],
366
- patches: np.ndarray, threshold: float = 0.5):
367
- """Visualize top archaeological candidates"""
368
- # Get patches as dict
369
- patches_dict = {'dtm': patches[:, 0, :, :]}
370
-
371
- candidates_bytes = generate_gate_positive_patches_visualization(
372
- unified_matrix, metadata, patches_dict, threshold=threshold, top_n=16
373
- )
374
-
375
- return Image.open(io.BytesIO(candidates_bytes))
376
-
377
- # ==============================================================================
378
- # GRADIO INTERFACE FUNCTIONS
379
- # ==============================================================================
380
-
381
- def load_models_ui():
382
- """Load models button handler"""
383
- status = model_manager.load_all_models()
384
-
385
- if model_manager.models_loaded:
386
- return status, gr.update(interactive=True) # Enable analysis tab
387
- else:
388
- return status, gr.update(interactive=False)
389
-
390
- def load_aoi_ui(aoi_name):
391
- """Load AOI handler"""
392
- status = data_manager.load_aoi(aoi_name)
393
- info = data_manager.get_aoi_info()
394
-
395
- # Update patch slider range
396
- if data_manager.current_patches is not None:
397
- max_patches = len(data_manager.current_patches) - 1
398
- return status, info, gr.update(maximum=max_patches, value=0)
399
-
400
- return status, info, gr.update()
401
-
402
- def view_patch_ui(patch_idx):
403
- """View specific patch"""
404
- if data_manager.current_patches is None:
405
- return None, None, None, "⚠️ Load an AOI first"
406
-
407
- patch, metadata = data_manager.get_patch(patch_idx)
408
-
409
- if patch is None:
410
- return None, None, None, "❌ Invalid patch index"
411
-
412
- # Create visualizations
413
- img_2d = create_patch_visualization(patch, metadata)
414
- fig_3d = create_3d_terrain(patch, metadata)
415
-
416
- # Probability channels if available
417
- prob_img = None
418
- if data_manager.current_unified_matrix is not None:
419
- prob_img = create_probability_channels_viz(
420
- data_manager.current_unified_matrix, patch_idx
421
- )
422
-
423
- info = f"**Patch {patch_idx}:** {metadata['patch_id']}\n" \
424
- f"**Position:** (row={metadata['row']}, col={metadata['col']})"
425
-
426
- return img_2d, fig_3d, prob_img, info
427
-
428
- def analyze_aoi_gate(threshold):
429
- """Analyze full AOI with GATE"""
430
- if data_manager.current_unified_matrix is None:
431
- return None, None, "⚠️ No unified matrix available for this AOI"
432
-
433
- # Generate heatmap
434
- heatmap = create_gate_heatmap(
435
- data_manager.current_unified_matrix,
436
- data_manager.current_metadata,
437
- data_manager.current_aoi
438
- )
439
-
440
- # Generate top candidates
441
- candidates = create_top_candidates(
442
- data_manager.current_unified_matrix,
443
- data_manager.current_metadata,
444
- data_manager.current_patches,
445
- threshold=threshold
446
- )
447
-
448
- # Statistics
449
- stats = generate_gate_statistics_json(
450
- data_manager.current_unified_matrix,
451
- threshold=threshold
452
- )
453
-
454
- stats_text = f"""
455
- **GATE Analysis Results**
456
-
457
- - Total Patches: {stats['total_patches']}
458
- - Positive Predictions: {stats['positive_patches']} ({stats['positive_percentage']:.1f}%)
459
- - Mean GATE Score: {stats['mean_gate_score']:.3f}
460
- - Max GATE Score: {stats['max_gate_score']:.3f}
461
- - Threshold: {stats['threshold']}
462
- """
463
-
464
- return heatmap, candidates, stats_text
465
-
466
- def batch_process_ui(start_aoi, end_aoi, progress=gr.Progress()):
467
- """Batch process multiple AOIs"""
468
- if not model_manager.models_loaded:
469
- return "❌ Load models first"
470
-
471
- aoi_list = data_manager.aoi_list[start_aoi:end_aoi+1]
472
- results = []
473
-
474
- for i, aoi_name in enumerate(progress.tqdm(aoi_list, desc="Processing AOIs")):
475
- try:
476
- data_manager.load_aoi(aoi_name)
477
-
478
- if data_manager.current_unified_matrix is not None:
479
- stats = generate_gate_statistics_json(
480
- data_manager.current_unified_matrix, threshold=0.5
481
- )
482
-
483
- results.append({
484
- 'AOI': aoi_name,
485
- 'Patches': stats['total_patches'],
486
- 'Positives': stats['positive_patches'],
487
- 'Positive %': f"{stats['positive_percentage']:.1f}%",
488
- 'Mean Score': f"{stats['mean_gate_score']:.3f}"
489
- })
490
- except Exception as e:
491
- results.append({
492
- 'AOI': aoi_name,
493
- 'Error': str(e)
494
- })
495
-
496
- df = pd.DataFrame(results)
497
- return df
498
-
499
- # ==============================================================================
500
- # BUILD GRADIO INTERFACE
501
- # ==============================================================================
502
-
503
- def build_interface():
504
- """Build the main Gradio interface"""
505
-
506
- with gr.Blocks(theme=config.THEME, title="Archaeological Site Detection") as app:
507
-
508
- # Header
509
- gr.Markdown("""
510
- # πŸ›οΈ Archaeological Site Detection System
511
- ### SONAR 2.0 - Multi-Model Analysis with GATE Predictions
512
-
513
- Explore archaeological sites using advanced AI detection across multiple AOIs.
514
- """)
515
-
516
- # Model Loading Section
517
- with gr.Row():
518
- with gr.Column(scale=3):
519
- gr.Markdown("### πŸ€– Step 1: Load AI Models")
520
- gr.Markdown("Load all trained models (Autoencoder, Isolation Forest, K-Means, GATE)")
521
-
522
- with gr.Column(scale=1):
523
- load_models_btn = gr.Button("πŸ”„ Load Models", variant="primary", size="lg")
524
-
525
- model_status = gr.Textbox(label="Model Status", lines=8, max_lines=10)
526
-
527
- gr.Markdown("---")
528
-
529
- # Main Interface
530
- with gr.Tabs():
531
-
532
- # TAB 1: Patch Explorer
533
- with gr.Tab("πŸ” Patch Explorer"):
534
- gr.Markdown("### Explore individual patches from any AOI")
535
-
536
- with gr.Row():
537
- with gr.Column(scale=1):
538
- aoi_dropdown = gr.Dropdown(
539
- choices=data_manager.aoi_list,
540
- label="Select AOI",
541
- value=data_manager.aoi_list[0] if data_manager.aoi_list else None
542
- )
543
-
544
- load_aoi_btn = gr.Button("πŸ“‚ Load AOI", variant="secondary")
545
-
546
- aoi_status = gr.Textbox(label="Load Status", lines=2)
547
- aoi_info = gr.Markdown("*No AOI loaded*")
548
-
549
- gr.Markdown("---")
550
-
551
- patch_slider = gr.Slider(
552
- minimum=0, maximum=100, step=1, value=0,
553
- label="Select Patch"
554
- )
555
-
556
- view_patch_btn = gr.Button("πŸ‘οΈ View Patch", variant="primary")
557
-
558
- patch_info = gr.Markdown("*Select a patch*")
559
-
560
- with gr.Column(scale=2):
561
- with gr.Tabs():
562
- with gr.Tab("πŸ“Š 2D View"):
563
- patch_2d = gr.Image(label="Patch Layers", type="pil")
564
-
565
- with gr.Tab("πŸ”οΈ 3D Terrain"):
566
- patch_3d = gr.Plot(label="3D Visualization")
567
-
568
- with gr.Tab("🎯 Probability Channels"):
569
- prob_channels = gr.Image(label="5 Probability Channels", type="pil")
570
-
571
- # TAB 2: GATE Analysis
572
- with gr.Tab("🎯 GATE Analysis", id="gate_tab"):
573
- gr.Markdown("### Full AOI Analysis with GATE Model")
574
-
575
- with gr.Row():
576
- threshold_slider = gr.Slider(
577
- minimum=0.0, maximum=1.0, value=0.5, step=0.05,
578
- label="GATE Threshold (higher = stricter)"
579
- )
580
-
581
- analyze_btn = gr.Button("πŸ”₯ Analyze AOI", variant="primary", size="lg")
582
-
583
- gate_stats = gr.Markdown("*Run analysis to see statistics*")
584
-
585
- with gr.Row():
586
- with gr.Column():
587
- gate_heatmap = gr.Image(label="GATE Prediction Heatmap", type="pil")
588
-
589
- with gr.Column():
590
- gate_candidates = gr.Image(label="Top Archaeological Candidates", type="pil")
591
-
592
- # TAB 3: Batch Processing
593
- with gr.Tab("⚑ Batch Processing"):
594
- gr.Markdown("### Process Multiple AOIs")
595
-
596
- with gr.Row():
597
- start_idx = gr.Number(label="Start AOI Index", value=0, precision=0)
598
- end_idx = gr.Number(label="End AOI Index", value=10, precision=0)
599
-
600
- batch_btn = gr.Button("πŸš€ Process Batch", variant="primary")
601
-
602
- batch_results = gr.Dataframe(
603
- headers=["AOI", "Patches", "Positives", "Positive %", "Mean Score"],
604
- label="Batch Results"
605
- )
606
-
607
- # TAB 4: Documentation
608
- with gr.Tab("πŸ“– Documentation"):
609
- gr.Markdown("""
610
- ## System Overview
611
-
612
- This system combines multiple AI models for archaeological site detection:
613
-
614
- ### Models
615
- 1. **Autoencoder (ResUNet)** - Detects anomalies via reconstruction error
616
- 2. **Isolation Forest** - Identifies outliers in latent space
617
- 3. **K-Means Clustering** - Matches against known archaeological signatures
618
- 4. **Archaeological Similarity** - Compares against verified sites
619
- 5. **GATE Model** - Meta-learner that combines all models for final prediction
620
-
621
- ### Workflow
622
- 1. Load models (one-time operation)
623
- 2. Select an AOI to analyze
624
- 3. Explore individual patches or run full AOI analysis
625
- 4. Review GATE predictions to identify high-probability archaeological sites
626
- 5. Batch process multiple AOIs for comprehensive surveys
627
-
628
- ### Channels
629
- The system uses 7 input channels:
630
- - **DTM**: Digital Terrain Model (LiDAR)
631
- - **Slope**: Terrain slope
632
- - **Roughness**: Surface roughness
633
- - **NDVI**: Normalized Difference Vegetation Index
634
- - **NDWI**: Normalized Difference Water Index
635
- - **Flow Accumulation**: Hydrological flow
636
- - **Flow Direction**: Water flow direction
637
-
638
- ### Output
639
- - **Probability Maps**: 5 probability channels (4 models + GATE)
640
- - **Heatmaps**: Full AOI visualization of predictions
641
- - **Candidate List**: Top archaeological site candidates
642
-
643
- ### Thresholding
644
- - **0.3-0.5**: Low threshold - more candidates, higher false positives
645
- - **0.5-0.7**: Balanced - recommended for initial surveys
646
- - **0.7-0.9**: High threshold - fewer candidates, high confidence
647
- """)
648
-
649
- # Event Handlers
650
- load_models_btn.click(
651
- fn=load_models_ui,
652
- outputs=[model_status, gr.Tabs.update(selected="gate_tab")]
653
- )
654
-
655
- load_aoi_btn.click(
656
- fn=load_aoi_ui,
657
- inputs=[aoi_dropdown],
658
- outputs=[aoi_status, aoi_info, patch_slider]
659
- )
660
-
661
- view_patch_btn.click(
662
- fn=view_patch_ui,
663
- inputs=[patch_slider],
664
- outputs=[patch_2d, patch_3d, prob_channels, patch_info]
665
- )
666
-
667
- analyze_btn.click(
668
- fn=analyze_aoi_gate,
669
- inputs=[threshold_slider],
670
- outputs=[gate_heatmap, gate_candidates, gate_stats]
671
- )
672
-
673
- batch_btn.click(
674
- fn=batch_process_ui,
675
- inputs=[start_idx, end_idx],
676
- outputs=[batch_results]
677
- )
678
-
679
- return app
680
-
681
- # ==============================================================================
682
- # MAIN
683
- # ==============================================================================
684
-
685
- if __name__ == "__main__":
686
- print("πŸ›οΈ Archaeological Site Detection System")
687
- print("=" * 60)
688
- print(f"Device: {config.DEVICE}")
689
- print(f"Discovered AOIs: {len(data_manager.aoi_list)}")
690
- print("=" * 60)
691
-
692
- app = build_interface()
693
-
694
- # Launch with options
695
- app.launch(
696
- server_name="0.0.0.0", # Allow external access
697
- server_port=7860,
698
- share=False, # Set to True for public URL
699
- show_error=True,
700
- show_api=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
  )
 
1
+ """
2
+ SONAR 2.0 - Archaeological Site Detection
3
+ Beautiful Interactive Interface with Full AOI Visualization
4
+ """
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib
10
+ matplotlib.use('Agg')
11
+ import plotly.graph_objects as go
12
+ from pathlib import Path
13
+ import torch
14
+ import io
15
+ from PIL import Image
16
+ import warnings
17
+ import zipfile
18
+ from scipy.ndimage import sobel, gaussian_filter
19
+ from matplotlib import cm
20
+ import rasterio
21
+ from rasterio.transform import rowcol
22
+ import folium
23
+ import base64
24
+ warnings.filterwarnings('ignore')
25
+
26
+ from utils import (
27
+ ResUNetAutoencoder, ResUNetEncoder, load_patches,
28
+ load_model, load_kmeans_model, load_unified_probability_matrix
29
+ )
30
+
31
+ # ==============================================================================
32
+ # CONFIGURATION
33
+ # ==============================================================================
34
+
35
+ class Config:
36
+ AUTOENCODER_PATH = Path('models/best_model_aoi.pth')
37
+ ENCODER_DIM = 128
38
+ IFOREST_PATH = Path('models/isolation_forest_model_128dim.pkl')
39
+ KMEANS_PATH = Path('models/kmeans_model_128dim.pkl')
40
+ GATE_MODEL_PKL = Path('models/gate_mlp_model.pkl')
41
+ GATE_SCALER_PATH = Path('models/gate_scaler.pkl')
42
+
43
+ DATA_BASE = Path('Test_dataset')
44
+ PATCHES_DIR = Path('patches_final_file')
45
+ UNIFIED_PROB_DIR = Path('test_unified_probablity_matrices_with_gate')
46
+
47
+ MODELS_ZIP = Path('models.zip')
48
+ PATCHES_ZIP = Path('patches_final_file.zip')
49
+ UNIFIED_PROB_ZIP = Path('test_unified_probablity_matrices_with_gate.zip')
50
+ DATASET_ZIP = Path('Test_dataset.zip')
51
+
52
+ BATCH_SIZE = 32
53
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
+
55
+ @staticmethod
56
+ def extract_data_files():
57
+ print("\n" + "="*50)
58
+ print("Starting SONAR 2.0...")
59
+ print("="*50)
60
+
61
+ if Config.MODELS_ZIP.exists() and (not Path('models').exists() or not any(Path('models').iterdir())):
62
+ print("Extracting models...")
63
+ with zipfile.ZipFile(Config.MODELS_ZIP, 'r') as zip_ref:
64
+ zip_ref.extractall('.')
65
+
66
+ if Config.PATCHES_ZIP.exists() and (not Config.PATCHES_DIR.exists() or not any(Config.PATCHES_DIR.iterdir())):
67
+ print("Extracting patches...")
68
+ with zipfile.ZipFile(Config.PATCHES_ZIP, 'r') as zip_ref:
69
+ zip_ref.extractall('.')
70
+
71
+ if Config.UNIFIED_PROB_ZIP.exists() and (not Config.UNIFIED_PROB_DIR.exists() or not any(Config.UNIFIED_PROB_DIR.iterdir())):
72
+ print("Extracting probability matrices...")
73
+ with zipfile.ZipFile(Config.UNIFIED_PROB_ZIP, 'r') as zip_ref:
74
+ zip_ref.extractall('.')
75
+
76
+ if Config.DATASET_ZIP.exists() and not Config.DATA_BASE.exists():
77
+ print("Extracting Test_dataset.zip...")
78
+ with zipfile.ZipFile(Config.DATASET_ZIP, 'r') as zip_ref:
79
+ zip_ref.extractall('.')
80
+ print("βœ“ Test dataset extracted")
81
+
82
+ print("Ready!\n")
83
+
84
+ config = Config()
85
+
86
+ # ==============================================================================
87
+ # DATA MANAGER
88
+ # ==============================================================================
89
+
90
+ class DataManager:
91
+ def __init__(self):
92
+ self.aoi_list = []
93
+ self.current_aoi = None
94
+ self.current_patches = None
95
+ self.current_metadata = None
96
+ self.current_unified_matrix = None
97
+ self.full_dtm = None
98
+ self.reference_transform = None
99
+ self.reference_shape = None
100
+ self.reference_bounds = None
101
+ self.reference_crs = None
102
+
103
+ def discover_aois(self):
104
+ if not config.PATCHES_DIR.exists():
105
+ return []
106
+ patch_files = list(config.PATCHES_DIR.glob("AOI_*_all_patches.npz"))
107
+ aoi_names = sorted([f.stem.replace('_all_patches', '') for f in patch_files])
108
+ self.aoi_list = aoi_names
109
+ return aoi_names
110
+
111
+ def load_original_raster(self, aoi_name: str):
112
+ """Load original DTM from Test_dataset"""
113
+ try:
114
+ meta_dir = config.DATA_BASE / aoi_name / 'meta'
115
+ if not meta_dir.exists():
116
+ return None
117
+
118
+ tif_files = list(meta_dir.glob('*.tif'))
119
+ if not tif_files:
120
+ return None
121
+
122
+ dtm_path = tif_files[0]
123
+ print(f"βœ“ Loading DTM: {dtm_path.name}")
124
+
125
+ with rasterio.open(dtm_path) as src:
126
+ dtm = src.read(1).astype(np.float32)
127
+ if src.nodata is not None:
128
+ dtm[dtm == src.nodata] = np.nan
129
+
130
+ self.reference_transform = src.transform
131
+ self.reference_shape = (src.height, src.width)
132
+ self.reference_bounds = src.bounds
133
+ self.reference_crs = src.crs
134
+
135
+ return dtm
136
+ except Exception as e:
137
+ print(f"Error loading raster: {e}")
138
+ return None
139
+
140
+ def load_aoi(self, aoi_name: str):
141
+ try:
142
+ patches_file = config.PATCHES_DIR / f"{aoi_name}_all_patches.npz"
143
+ self.current_patches, self.current_metadata = load_patches(patches_file)
144
+
145
+ matrix_file = config.UNIFIED_PROB_DIR / f"{aoi_name}_unified_prob_matrix.npz"
146
+ if matrix_file.exists():
147
+ self.current_unified_matrix, _, _ = load_unified_probability_matrix(
148
+ aoi_name, config.UNIFIED_PROB_DIR
149
+ )
150
+
151
+ self.full_dtm = self.load_original_raster(aoi_name)
152
+
153
+ if self.full_dtm is not None:
154
+ max_row = max(m['row'] + 64 for m in self.current_metadata)
155
+ max_col = max(m['col'] + 64 for m in self.current_metadata)
156
+ dtm_height, dtm_width = self.full_dtm.shape
157
+
158
+ if max_row > dtm_height or max_col > dtm_width:
159
+ print(f"⚠️ Patch bounds exceed raster, falling back to reconstruction")
160
+ self.full_dtm = None
161
+
162
+ self.current_aoi = aoi_name
163
+ return f"βœ… {aoi_name} loaded ({len(self.current_patches)} patches)"
164
+ except Exception as e:
165
+ return f"❌ Error: {e}"
166
+
167
+ def get_patch(self, patch_idx: int):
168
+ if self.current_patches is None or patch_idx >= len(self.current_patches):
169
+ return None, None
170
+ return self.current_patches[patch_idx], self.current_metadata[patch_idx]
171
+
172
+ def find_patch_at_pixel(self, row: int, col: int):
173
+ for idx, meta in enumerate(self.current_metadata):
174
+ patch_row, patch_col = meta['row'], meta['col']
175
+ if (patch_row <= row < patch_row + 64 and patch_col <= col < patch_col + 64):
176
+ return idx
177
+ return None
178
+
179
+ data_manager = DataManager()
180
+
181
+ # ==============================================================================
182
+ # MAP GENERATION WITH MULTIPLE OVERLAYS (LIKE STREAMLIT VERSION)
183
+ # ==============================================================================
184
+
185
+ def create_interactive_map(aoi_name, threshold=0.5):
186
+ """
187
+ Create beautiful Folium map with multiple terrain overlays
188
+ Matches the Streamlit visualization quality
189
+ """
190
+
191
+ if data_manager.full_dtm is None or data_manager.reference_bounds is None:
192
+ # Fallback: create a simple map showing patches without terrain overlays
193
+ if data_manager.current_patches is None:
194
+ return "<div style='padding: 20px; text-align: center;'>⚠️ Load an AOI first</div>"
195
+
196
+ # Calculate approximate center from patches
197
+ if data_manager.current_metadata:
198
+ rows = [m['row'] for m in data_manager.current_metadata]
199
+ cols = [m['col'] for m in data_manager.current_metadata]
200
+ center_row = (min(rows) + max(rows)) / 2
201
+ center_col = (min(cols) + max(cols)) / 2
202
+
203
+ # Use approximate lat/lon (this is a fallback)
204
+ center_lat, center_lon = 0, 0 # Will be set properly below
205
+
206
+ m = folium.Map(
207
+ location=[center_lat, center_lon],
208
+ zoom_start=14,
209
+ tiles='OpenStreetMap'
210
+ )
211
+
212
+ # Add patch markers
213
+ if data_manager.current_unified_matrix is not None:
214
+ gate_channel = data_manager.current_unified_matrix[:, :, :, 4]
215
+
216
+ for idx, meta in enumerate(data_manager.current_metadata):
217
+ patch_score = np.mean(gate_channel[idx])
218
+
219
+ if patch_score >= threshold:
220
+ # For fallback, just use patch indices as approximate locations
221
+ popup_html = f"""
222
+ <b>πŸ”΄ Anomaly Detected</b><br>
223
+ <hr>
224
+ <b>Score:</b> {patch_score:.3f}<br>
225
+ <b>Patch ID:</b> {idx}<br>
226
+ <hr>
227
+ <i>⚠️ DTM not available - using patch coordinates</i>
228
+ """
229
+
230
+ # This is approximate - will work better with actual DTM
231
+ lat_approx = meta['row'] / 1000.0
232
+ lon_approx = meta['col'] / 1000.0
233
+
234
+ folium.CircleMarker(
235
+ location=[lat_approx, lon_approx],
236
+ radius=8,
237
+ popup=folium.Popup(popup_html, max_width=300),
238
+ tooltip=f"Patch {idx} - Score: {patch_score:.3f}",
239
+ color='red',
240
+ fill=True,
241
+ fillColor='orange',
242
+ fillOpacity=0.7,
243
+ weight=2
244
+ ).add_to(m)
245
+
246
+ folium.LayerControl().add_to(m)
247
+
248
+ return m._repr_html_()
249
+
250
+ return "<div style='padding: 20px; text-align: center; background: #fff3cd; border-radius: 8px;'>⚠️ No DTM data available for this AOI. Map visualization limited.</div>"
251
+
252
+ bounds = data_manager.reference_bounds
253
+ center_lat = (bounds.bottom + bounds.top) / 2
254
+ center_lon = (bounds.left + bounds.right) / 2
255
+
256
+ # Create base map
257
+ m = folium.Map(
258
+ location=[center_lat, center_lon],
259
+ zoom_start=14,
260
+ tiles='OpenStreetMap'
261
+ )
262
+
263
+ # Add center marker
264
+ folium.Marker(
265
+ location=[center_lat, center_lon],
266
+ popup=f'AOI Center<br>Lat: {center_lat:.6f}<br>Lon: {center_lon:.6f}',
267
+ tooltip='AOI Center',
268
+ icon=folium.Icon(color='blue', icon='info-sign')
269
+ ).add_to(m)
270
+
271
+ # Generate terrain overlays
272
+ dtm = data_manager.full_dtm
273
+ valid_mask = ~np.isnan(dtm)
274
+
275
+ if valid_mask.any():
276
+ dtm_filled = dtm.copy()
277
+ dtm_filled[~valid_mask] = np.nanmedian(dtm)
278
+
279
+ # ============================================================
280
+ # LAYER 1: Local Relief Model (ARCHAEOLOGICAL GOLD!)
281
+ # ============================================================
282
+ dtm_smooth = gaussian_filter(dtm_filled, sigma=10)
283
+ local_relief = dtm_filled - dtm_smooth
284
+
285
+ relief_clipped = np.clip(local_relief, -2, 2)
286
+ relief_norm = (relief_clipped + 2) / 4
287
+
288
+ rdbu_cmap = cm.get_cmap('RdBu_r')
289
+ relief_rgba = rdbu_cmap(relief_norm)
290
+ relief_rgb = (relief_rgba[:, :, :3] * 255).astype(np.uint8)
291
+ relief_rgb[~valid_mask] = [128, 128, 128]
292
+
293
+ img_relief = Image.fromarray(relief_rgb, mode='RGB')
294
+ buffered = io.BytesIO()
295
+ img_relief.save(buffered, format="PNG")
296
+ img_str = base64.b64encode(buffered.getvalue()).decode()
297
+
298
+ folium.raster_layers.ImageOverlay(
299
+ image=f'data:image/png;base64,{img_str}',
300
+ bounds=[[bounds.bottom, bounds.left], [bounds.top, bounds.right]],
301
+ opacity=0.75,
302
+ name='πŸ›οΈ Local Relief Model (Archaeological)',
303
+ overlay=True,
304
+ control=True,
305
+ show=True # DEFAULT ON
306
+ ).add_to(m)
307
+
308
+ # ============================================================
309
+ # LAYER 2: Multi-Directional Hillshade
310
+ # ============================================================
311
+ dx = sobel(dtm_filled, axis=1) / 8.0
312
+ dy = sobel(dtm_filled, axis=0) / 8.0
313
+ slope = np.arctan(np.sqrt(dx**2 + dy**2))
314
+ aspect = np.arctan2(-dy, dx)
315
+
316
+ azimuths = [315, 45, 225, 135]
317
+ altitude = 45
318
+ hillshades = []
319
+
320
+ for az_deg in azimuths:
321
+ azimuth = np.radians(az_deg)
322
+ alt_rad = np.radians(altitude)
323
+ hs = (np.sin(alt_rad) * np.sin(slope) +
324
+ np.cos(alt_rad) * np.cos(slope) *
325
+ np.cos(azimuth - aspect))
326
+ hillshades.append(hs)
327
+
328
+ hillshade_multi = np.mean(hillshades, axis=0)
329
+ hillshade_multi = np.clip(hillshade_multi, -1, 1)
330
+ hillshade_multi = ((hillshade_multi + 1) / 2 * 255).astype(np.uint8)
331
+ hillshade_multi[~valid_mask] = 128
332
+
333
+ hillshade_multi_rgb = np.stack([hillshade_multi, hillshade_multi, hillshade_multi], axis=-1)
334
+
335
+ # Add color tinting
336
+ dtm_norm = (dtm - np.nanpercentile(dtm[valid_mask], 2)) / \
337
+ (np.nanpercentile(dtm[valid_mask], 98) -
338
+ np.nanpercentile(dtm[valid_mask], 2))
339
+ dtm_norm = np.clip(dtm_norm, 0, 1)
340
+
341
+ terrain_cmap = cm.get_cmap('terrain')
342
+ terrain_rgba = terrain_cmap(dtm_norm)
343
+ terrain_rgb = (terrain_rgba[:, :, :3] * 255).astype(np.uint8)
344
+
345
+ hillshade_multi_rgb = (hillshade_multi_rgb * 0.75 + terrain_rgb * 0.25).astype(np.uint8)
346
+ hillshade_multi_rgb[~valid_mask] = [128, 128, 128]
347
+
348
+ img_multi = Image.fromarray(hillshade_multi_rgb, mode='RGB')
349
+ buffered_multi = io.BytesIO()
350
+ img_multi.save(buffered_multi, format="PNG")
351
+ img_str_multi = base64.b64encode(buffered_multi.getvalue()).decode()
352
+
353
+ folium.raster_layers.ImageOverlay(
354
+ image=f'data:image/png;base64,{img_str_multi}',
355
+ bounds=[[bounds.bottom, bounds.left], [bounds.top, bounds.right]],
356
+ opacity=0.7,
357
+ name='πŸ—» Multi-Directional Hillshade',
358
+ overlay=True,
359
+ control=True,
360
+ show=False
361
+ ).add_to(m)
362
+
363
+ # ============================================================
364
+ # LAYER 3: Standard Terrain
365
+ # ============================================================
366
+ dtm_norm_basic = (dtm - np.nanpercentile(dtm[valid_mask], 2)) / \
367
+ (np.nanpercentile(dtm[valid_mask], 98) -
368
+ np.nanpercentile(dtm[valid_mask], 2))
369
+ dtm_norm_basic = np.clip(dtm_norm_basic, 0, 1)
370
+
371
+ terrain_basic_rgba = terrain_cmap(dtm_norm_basic)
372
+ terrain_basic_rgb = (terrain_basic_rgba[:, :, :3] * 255).astype(np.uint8)
373
+ terrain_basic_rgb[~valid_mask] = [128, 128, 128]
374
+
375
+ img_basic = Image.fromarray(terrain_basic_rgb, mode='RGB')
376
+ buffered_basic = io.BytesIO()
377
+ img_basic.save(buffered_basic, format="PNG")
378
+ img_str_basic = base64.b64encode(buffered_basic.getvalue()).decode()
379
+
380
+ folium.raster_layers.ImageOverlay(
381
+ image=f'data:image/png;base64,{img_str_basic}',
382
+ bounds=[[bounds.bottom, bounds.left], [bounds.top, bounds.right]],
383
+ opacity=0.65,
384
+ name='🌍 Standard Terrain',
385
+ overlay=True,
386
+ control=True,
387
+ show=False
388
+ ).add_to(m)
389
+
390
+ # Add anomaly markers
391
+ if data_manager.current_unified_matrix is not None:
392
+ gate_channel = data_manager.current_unified_matrix[:, :, :, 4]
393
+
394
+ for idx, meta in enumerate(data_manager.current_metadata):
395
+ patch_score = np.mean(gate_channel[idx])
396
+
397
+ if patch_score >= threshold:
398
+ row, col = meta['row'], meta['col']
399
+ center_row = row + 32
400
+ center_col = col + 32
401
+ x, y = rasterio.transform.xy(data_manager.reference_transform, center_row, center_col)
402
+
403
+ if data_manager.reference_crs != 'EPSG:4326':
404
+ from rasterio.warp import transform as transform_coords
405
+ lon, lat = transform_coords(data_manager.reference_crs, 'EPSG:4326', [x], [y])
406
+ lat, lon = lat[0], lon[0]
407
+ else:
408
+ lat, lon = y, x
409
+
410
+ popup_html = f"""
411
+ <b>πŸ”΄ Anomaly Detected</b><br>
412
+ <hr>
413
+ <b>Score:</b> {patch_score:.3f}<br>
414
+ <b>Patch ID:</b> {idx}<br>
415
+ <b>Location:</b><br>
416
+ Lat: {lat:.6f}<br>
417
+ Lon: {lon:.6f}<br>
418
+ <hr>
419
+ <i>Click map at this location to inspect</i>
420
+ """
421
+
422
+ folium.CircleMarker(
423
+ location=[lat, lon],
424
+ radius=8,
425
+ popup=folium.Popup(popup_html, max_width=300),
426
+ tooltip=f"Anomaly Score: {patch_score:.3f}",
427
+ color='red',
428
+ fill=True,
429
+ fillColor='orange',
430
+ fillOpacity=0.7,
431
+ weight=2
432
+ ).add_to(m)
433
+
434
+ # Add boundary
435
+ folium.Rectangle(
436
+ bounds=[[bounds.bottom, bounds.left], [bounds.top, bounds.right]],
437
+ color='red',
438
+ fill=False,
439
+ weight=2,
440
+ popup=f'{aoi_name} boundary'
441
+ ).add_to(m)
442
+
443
+ folium.LayerControl().add_to(m)
444
+
445
+ return m._repr_html_()
446
+
447
+ # ==============================================================================
448
+ # PATCH VISUALIZATION
449
+ # ==============================================================================
450
+
451
+ def create_patch_viz(patch: np.ndarray, metadata: dict):
452
+ """Beautiful 2D patch visualization"""
453
+ channel_names = ['DTM', 'Slope', 'Roughness', 'NDVI', 'NDWI', 'Flow Acc', 'Flow Dir']
454
+
455
+ fig, axes = plt.subplots(2, 4, figsize=(16, 8), facecolor='white')
456
+ axes = axes.flatten()
457
+
458
+ for i in range(7):
459
+ ax = axes[i]
460
+ data = patch[i]
461
+ cmap = ['terrain', 'YlOrRd', 'viridis', 'RdYlGn', 'Blues', 'cividis', 'twilight'][i]
462
+
463
+ im = ax.imshow(data, cmap=cmap, interpolation='bilinear')
464
+ ax.set_title(channel_names[i], fontsize=12, fontweight='bold', pad=8)
465
+ ax.axis('off')
466
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
467
+
468
+ axes[7].axis('off')
469
+
470
+ plt.suptitle(f"Patch {metadata['patch_id']} | Row {metadata['row']} | Col {metadata['col']}",
471
+ fontsize=16, fontweight='bold', y=0.98)
472
+ plt.tight_layout()
473
+
474
+ buf = io.BytesIO()
475
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
476
+ plt.close(fig)
477
+ buf.seek(0)
478
+
479
+ return Image.open(buf)
480
+
481
+ def create_3d_terrain(patch: np.ndarray, metadata: dict):
482
+ """Enhanced 3D terrain visualization"""
483
+ dtm = patch[0]
484
+ dtm_clean = np.nan_to_num(dtm, nan=np.nanmedian(dtm))
485
+
486
+ rows, cols = dtm.shape
487
+ x, y = np.arange(cols), np.arange(rows)
488
+ X, Y = np.meshgrid(x, y)
489
+
490
+ fig = go.Figure(data=[go.Surface(
491
+ z=dtm_clean, x=X, y=Y,
492
+ colorscale='earth',
493
+ showscale=True,
494
+ lighting=dict(
495
+ ambient=0.4,
496
+ diffuse=0.8,
497
+ fresnel=0.2,
498
+ specular=0.3,
499
+ roughness=0.5
500
+ ),
501
+ contours=dict(
502
+ z=dict(
503
+ show=True,
504
+ usecolormap=True,
505
+ highlightcolor="limegreen",
506
+ project=dict(z=True)
507
+ )
508
+ )
509
+ )])
510
+
511
+ fig.update_layout(
512
+ title=f"3D Terrain Β· Patch {metadata['patch_id']}",
513
+ scene=dict(
514
+ xaxis_title='X (pixels)',
515
+ yaxis_title='Y (pixels)',
516
+ zaxis_title='Elevation (m)',
517
+ camera=dict(eye=dict(x=1.5, y=1.5, z=1.3)),
518
+ aspectmode='manual',
519
+ aspectratio=dict(x=1, y=1, z=0.5)
520
+ ),
521
+ height=600,
522
+ margin=dict(l=0, r=0, t=40, b=0)
523
+ )
524
+
525
+ return fig
526
+
527
+ # ==============================================================================
528
+ # UI FUNCTIONS
529
+ # ==============================================================================
530
+
531
+ def load_aoi_and_generate_map(aoi_name, threshold):
532
+ """Load AOI and generate beautiful map"""
533
+ status = data_manager.load_aoi(aoi_name)
534
+
535
+ if "βœ…" in status:
536
+ map_html = create_interactive_map(aoi_name, threshold)
537
+
538
+ # Generate statistics (with safe access)
539
+ stats_html = f"""
540
+ <div style='padding: 15px; background: #f0f7ff; border-radius: 8px; margin: 10px 0;'>
541
+ <h3 style='margin-top: 0; color: #1976d2;'>πŸ“Š AOI Statistics</h3>
542
+ <p><b>Total Patches:</b> {len(data_manager.current_patches)}</p>
543
+ """
544
+
545
+ if data_manager.reference_shape:
546
+ stats_html += f"<p><b>Raster Shape:</b> {data_manager.reference_shape}</p>"
547
+
548
+ if data_manager.reference_crs:
549
+ stats_html += f"<p><b>CRS:</b> {data_manager.reference_crs}</p>"
550
+
551
+ if data_manager.reference_bounds:
552
+ center_lat = (data_manager.reference_bounds.bottom + data_manager.reference_bounds.top) / 2
553
+ center_lon = (data_manager.reference_bounds.left + data_manager.reference_bounds.right) / 2
554
+ stats_html += f"<p><b>Center:</b> {center_lat:.6f}, {center_lon:.6f}</p>"
555
+
556
+ if data_manager.full_dtm is not None:
557
+ stats_html += "<p><b>DTM Status:</b> βœ… Loaded</p>"
558
+ else:
559
+ stats_html += "<p><b>DTM Status:</b> ⚠️ Not available (using reconstructed patches)</p>"
560
+
561
+ stats_html += "</div>"
562
+
563
+ return status, map_html, stats_html, "", None, None
564
+
565
+ return status, "<div style='padding: 20px;'>❌ Failed to load AOI</div>", "", "", None, None
566
+
567
+ def update_threshold(threshold):
568
+ """Update anomaly detection threshold"""
569
+ if data_manager.current_aoi:
570
+ return create_interactive_map(data_manager.current_aoi, threshold)
571
+ return "<div style='padding: 20px;'>⚠️ Load an AOI first</div>"
572
+
573
+ def handle_map_click(lat, lon):
574
+ """Handle click on map - extract and visualize patch"""
575
+ if data_manager.full_dtm is None:
576
+ return "⚠️ Load an AOI first", None, None
577
+
578
+ from rasterio.warp import transform as transform_coords
579
+
580
+ if data_manager.reference_crs != 'EPSG:4326':
581
+ x, y = transform_coords('EPSG:4326', data_manager.reference_crs, [lon], [lat])
582
+ lon, lat = x[0], y[0]
583
+
584
+ row, col = rowcol(data_manager.reference_transform, lon, lat)
585
+ patch_idx = data_manager.find_patch_at_pixel(row, col)
586
+
587
+ if patch_idx is None:
588
+ return f"❌ No patch at ({row}, {col})", None, None
589
+
590
+ patch, metadata = data_manager.get_patch(patch_idx)
591
+
592
+ if patch is None:
593
+ return "❌ Error loading patch", None, None
594
+
595
+ # Generate visualizations
596
+ img_2d = create_patch_viz(patch, metadata)
597
+ fig_3d = create_3d_terrain(patch, metadata)
598
+
599
+ # Get score
600
+ score_text = ""
601
+ if data_manager.current_unified_matrix is not None:
602
+ gate_channel = data_manager.current_unified_matrix[:, :, :, 4]
603
+ patch_score = np.mean(gate_channel[patch_idx])
604
+ score_text = f" | <span style='color: {'red' if patch_score > 0.5 else 'green'}; font-weight: bold;'>Score: {patch_score:.3f}</span>"
605
+
606
+ info = f"<div style='padding: 10px; background: #e8f5e9; border-radius: 5px;'><b>βœ… Patch {patch_idx}</b> | ID: {metadata['patch_id']}{score_text}</div>"
607
+
608
+ return info, img_2d, fig_3d
609
+
610
+ # ==============================================================================
611
+ # BUILD INTERFACE
612
+ # ==============================================================================
613
+
614
+ def build_interface():
615
+
616
+ custom_css = """
617
+ .gradio-container {
618
+ max-width: 1400px !important;
619
+ }
620
+ .map-container {
621
+ border: 2px solid #1976d2;
622
+ border-radius: 8px;
623
+ overflow: hidden;
624
+ }
625
+ """
626
+
627
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="SONAR 2.0") as app:
628
+
629
+ gr.Markdown("""
630
+ # πŸ—ΊοΈ SONAR 2.0 - Archaeological Anomaly Detection System
631
+ ### Interactive geospatial analysis with multi-layer terrain visualization
632
+ """)
633
+
634
+ with gr.Row():
635
+ # Left sidebar - Controls
636
+ with gr.Column(scale=1):
637
+ gr.Markdown("### πŸŽ›οΈ Control Panel")
638
+
639
+ aoi_dropdown = gr.Dropdown(
640
+ choices=data_manager.aoi_list,
641
+ label="πŸ“ Select Area of Interest (AOI)",
642
+ value=data_manager.aoi_list[0] if data_manager.aoi_list else None,
643
+ info="Choose an AOI to analyze"
644
+ )
645
+
646
+ threshold_slider = gr.Slider(
647
+ minimum=0.0,
648
+ maximum=1.0,
649
+ value=0.5,
650
+ step=0.05,
651
+ label="🎯 Detection Threshold",
652
+ info="Higher = fewer, more confident detections"
653
+ )
654
+
655
+ load_btn = gr.Button("πŸš€ Load AOI & Generate Map", variant="primary", size="lg")
656
+
657
+ status_box = gr.HTML(label="Status")
658
+ stats_box = gr.HTML(label="Statistics")
659
+
660
+ gr.Markdown("""
661
+ ---
662
+ ### πŸ“– How to Use
663
+
664
+ 1. **Select AOI** from dropdown
665
+ 2. **Click "Load AOI"** to generate map
666
+ 3. **Explore layers** using map controls:
667
+ - πŸ›οΈ Local Relief (archaeological features)
668
+ - πŸ—» Multi-directional hillshade
669
+ - 🌍 Standard terrain
670
+ 4. **Click anywhere** on the map to inspect
671
+ 5. **View 2D/3D** visualizations below
672
+
673
+ πŸ”΄ **Red markers** = Detected anomalies
674
+ πŸ”΅ **Blue marker** = AOI center
675
+ """)
676
+
677
+ # Right side - Map
678
+ with gr.Column(scale=3):
679
+ gr.Markdown("### πŸ—ΊοΈ Interactive Map (Click to Inspect)")
680
+ map_display = gr.HTML(
681
+ value="<div style='padding: 40px; text-align: center; background: #f5f5f5; border-radius: 8px;'>πŸ‘ˆ Select an AOI and click 'Load AOI' to view map</div>",
682
+ elem_classes=["map-container"]
683
+ )
684
+
685
+ gr.Markdown("### πŸ” Manual Inspection")
686
+ gr.Markdown("*Enter coordinates from map click (Lat/Lon):*")
687
+
688
+ with gr.Row():
689
+ lat_input = gr.Number(label="Latitude", precision=6, scale=1)
690
+ lon_input = gr.Number(label="Longitude", precision=6, scale=1)
691
+ inspect_btn = gr.Button("πŸ”Ž Inspect Patch", variant="primary", scale=1)
692
+
693
+ patch_info = gr.HTML(label="Patch Information")
694
+
695
+ gr.Markdown("---")
696
+ gr.Markdown("### πŸ“Š Detailed Patch Analysis")
697
+
698
+ with gr.Row():
699
+ with gr.Column():
700
+ gr.Markdown("#### πŸ–ΌοΈ Multi-Channel 2D View")
701
+ patch_2d = gr.Image(label="2D Layer Analysis", type="pil")
702
+
703
+ with gr.Column():
704
+ gr.Markdown("#### πŸ”οΈ 3D Terrain Model")
705
+ terrain_3d = gr.Plot(label="Interactive 3D Visualization")
706
+
707
+ # Event handlers
708
+ load_btn.click(
709
+ fn=load_aoi_and_generate_map,
710
+ inputs=[aoi_dropdown, threshold_slider],
711
+ outputs=[status_box, map_display, stats_box, patch_info, patch_2d, terrain_3d]
712
+ )
713
+
714
+ threshold_slider.change(
715
+ fn=update_threshold,
716
+ inputs=[threshold_slider],
717
+ outputs=[map_display]
718
+ )
719
+
720
+ inspect_btn.click(
721
+ fn=handle_map_click,
722
+ inputs=[lat_input, lon_input],
723
+ outputs=[patch_info, patch_2d, terrain_3d]
724
+ )
725
+
726
+ gr.Markdown("""
727
+ ---
728
+ ### πŸ’‘ Tips
729
+ - Toggle between layers using the map control (top-right)
730
+ - Adjust threshold to see more/fewer anomalies
731
+ - Click directly on red markers for quick inspection
732
+ - Use 3D view to assess terrain relief and features
733
+
734
+ **Powered by SONAR 2.0** | Archaeological AI Detection System
735
+ """)
736
+
737
+ return app
738
+
739
+ # ==============================================================================
740
+ # MAIN
741
+ # ==============================================================================
742
+
743
+ if __name__ == "__main__":
744
+ print("\n" + "="*60)
745
+ print("πŸ—ΊοΈ SONAR 2.0 - Archaeological Site Detection")
746
+ print("="*60)
747
+
748
+ config.extract_data_files()
749
+ data_manager.discover_aois()
750
+
751
+ print(f"\nπŸ“Š System Information:")
752
+ print(f" Device: {config.DEVICE}")
753
+ print(f" Available AOIs: {len(data_manager.aoi_list)}")
754
+ print(f" AOI Names: {', '.join(data_manager.aoi_list)}")
755
+ print("="*60 + "\n")
756
+
757
+ app = build_interface()
758
+
759
+ app.launch(
760
+ server_name="0.0.0.0",
761
+ server_port=7860,
762
+ share=False, # Set to True for public URL
763
+ show_error=True,
764
+ show_api=True
765
  )