Abdulahad79 commited on
Commit
afd2d68
Β·
verified Β·
1 Parent(s): e2aed3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +718 -716
app.py CHANGED
@@ -1,717 +1,719 @@
1
- import gradio as gr
2
- import numpy as np
3
- import pandas as pd
4
- import torch
5
- import torch.nn as nn
6
- import joblib
7
- import matplotlib.pyplot as plt
8
- from matplotlib.patches import Patch
9
- import matplotlib
10
- from shapely.geometry import shape, Point
11
- import folium
12
- from folium.plugins import Draw
13
- from io import BytesIO
14
- import base64
15
- import json
16
- import os
17
- from PIL import Image
18
- import ee
19
- from datetime import datetime, timedelta
20
- import rasterio
21
- from rasterio.transform import xy
22
-
23
- # Initialize Earth Engine
24
- try:
25
- ee.Initialize(project='artful-striker-466710-b3')
26
- except Exception as e:
27
- print(f"Error initializing GEE: {str(e)}")
28
- ee.Authenticate()
29
- ee.Initialize(project='artful-striker-466710-b3')
30
-
31
- # Define crop season dictionary
32
- crop_season_dict = {
33
- "Punjab": {
34
- "Rabi": [
35
- "wheat", "barley", "gram (chickpea)", "lentil", "mustard", "rapeseed mustard",
36
- "linseed", "peas", "garlic", "onion", "coriander", "fennel", "potato",
37
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
38
- ],
39
- "Kharif": [
40
- "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "sorghum", "sunflower",
41
- "groundnuts", "okra", "tomato", "chillies", "banana", "mango",
42
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
43
- ]
44
- },
45
- "Sindh": {
46
- "Rabi": [
47
- "wheat", "barley", "peas", "gram (chickpea)", "mustard", "onion", "garlic", "spinach",
48
- "coriander", "potato", "fennel", "turnip",
49
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
50
- ],
51
- "Kharif": [
52
- "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "okra", "tomato",
53
- "chillies", "banana", "mango", "sunflower", "guava",
54
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
55
- ]
56
- },
57
- "Balochistan": {
58
- "Rabi": [
59
- "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "potato",
60
- "onion", "coriander", "fallow (agriculture)", "water", "barren", "shrubs", "forest"
61
- ],
62
- "Kharif": [
63
- "maize", "rice", "millet", "sorghum", "peach", "apple", "grapes", "tomato",
64
- "chillies", "pomegranate", "groundnuts", "sunflower",
65
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
66
- ]
67
- },
68
- "Khyber Pakhtunkhwa": {
69
- "Rabi": [
70
- "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "onion",
71
- "garlic", "turnip", "potato", "coriander",
72
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
73
- ],
74
- "Kharif": [
75
- "maize", "rice", "sugarcane", "tomato", "chillies", "peach", "plum", "apricot",
76
- "apple", "mango", "sunflower", "okra", "sesame",
77
- "fallow (agriculture)", "water", "barren", "shrubs", "forest"
78
- ]
79
- }
80
- }
81
-
82
- # Define model
83
- class CropClassifier(nn.Module):
84
- def __init__(self, input_size, num_classes):
85
- super(CropClassifier, self).__init__()
86
- self.network = nn.Sequential(
87
- nn.Linear(input_size, 512),
88
- nn.BatchNorm1d(512),
89
- nn.LeakyReLU(),
90
- nn.Dropout(0.4),
91
- nn.Linear(512, 256),
92
- nn.BatchNorm1d(256),
93
- nn.LeakyReLU(),
94
- nn.Dropout(0.3),
95
- nn.Linear(256, 128),
96
- nn.BatchNorm1d(128),
97
- nn.LeakyReLU(),
98
- nn.Dropout(0.2),
99
- nn.Linear(128, 64),
100
- nn.BatchNorm1d(64),
101
- nn.LeakyReLU(),
102
- nn.Dropout(0.1),
103
- nn.Linear(64, num_classes)
104
- )
105
- def forward(self, x):
106
- return self.network(x)
107
-
108
- # Load saved objects
109
- scaler = joblib.load("scaler.pkl")
110
- label_to_idx = joblib.load("label_encoder.pkl")
111
- feature_columns = joblib.load("feature_columns.pkl")
112
- idx_to_label = {v: k for k, v in label_to_idx.items()}
113
-
114
- # Load model
115
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
- model = CropClassifier(len(feature_columns), len(label_to_idx)).to(device)
117
- model.load_state_dict(torch.load("final_crop_model.pth", map_location=device))
118
- model.eval()
119
-
120
- # Uncertainty threshold
121
- uncertainty_threshold = 0.2
122
- uncertain_class_idx = len(label_to_idx)
123
- idx_to_label[uncertain_class_idx] = "Uncertain"
124
-
125
- # Global variable to store current polygon
126
- current_polygon_data = None
127
-
128
- def get_color_palette(n):
129
- if n <= 20:
130
- palette = list(matplotlib.colors.TABLEAU_COLORS.values()) + list(matplotlib.colors.CSS4_COLORS.values())
131
- return palette[:n]
132
- else:
133
- return [matplotlib.colors.rgb2hex(matplotlib.cm.hsv(i/n)) for i in range(n)]
134
-
135
- def assign_crop_colors(unique_crops):
136
- palette = get_color_palette(len(unique_crops))
137
- return {crop: palette[i] for i, crop in enumerate(unique_crops)}
138
-
139
- def get_valid_user_classes(province, season):
140
- """Fetch valid classes based on province and season from crop_season_dict."""
141
- try:
142
- user_classes = crop_season_dict.get(province, {}).get(season, [])
143
- return [cls for cls in user_classes if cls in label_to_idx]
144
- except:
145
- return []
146
-
147
- # --- Upload Processing Function ---
148
- def process_upload(file, province, season, date):
149
- if file is None:
150
- return "No file uploaded. Please upload a .tiff or .tif file.", None
151
-
152
- if not file.name.endswith(('.tiff', '.tif')):
153
- return "Unsupported file format. Please upload a .tiff or .tif file.", None
154
-
155
- # Load GeoTIFF file
156
- try:
157
- with rasterio.open(file) as src:
158
- patch = src.read() # Shape: (bands, height, width)
159
- transform = src.transform
160
- rows, cols = patch.shape[1], patch.shape[2]
161
- row_indices, col_indices = np.meshgrid(np.arange(rows), np.arange(cols), indexing='ij')
162
- lon, lat = xy(transform, row_indices, col_indices)
163
- # Convert lon, lat to 2D arrays (shape: [rows, cols])
164
- lon_mask = np.array(lon).reshape(rows, cols)
165
- lat_mask = np.array(lat).reshape(rows, cols)
166
- except Exception as e:
167
- return f"Error reading GeoTIFF file: {str(e)}", None
168
-
169
- # Validate the number of bands
170
- if len(patch.shape) != 3 or patch.shape[0] < 7:
171
- return "Invalid GeoTIFF file format. Expected at least 7 bands [r, g, b, rededge, nir, swr1, swr2].", None
172
-
173
- # # Resize patch to 500x500 if necessary
174
- # patch = patch[:, :500, :500]
175
- patch = np.transpose(patch, (1, 2, 0)) # Shape: (H, W, 7)
176
- H, W, _ = patch.shape
177
-
178
- # Extract RGB for visualization
179
- r, g, b = patch[..., 0], patch[..., 1], patch[..., 2]
180
- rgb = np.stack([r, g, b], axis=-1).astype(np.float32)
181
- rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)
182
-
183
-
184
- # Process pixels for prediction
185
- pixels = []
186
- for i in range(H):
187
- for j in range(W):
188
- pix = patch[i, j].astype(np.float32)
189
- red, green, blue, nir, swr1 = pix[0], pix[1], pix[2], pix[4], pix[5]
190
- pixels.append({
191
- "Province": province,
192
- "Season": season,
193
- "Latitude": lat_mask[i, j],
194
- "Longitude": lon_mask[i, j],
195
- "NDVI": (nir - red) / (nir + red + 1e-6),
196
- "NDWI": (green - nir) / (green + nir + 1e-6),
197
- "NDBI": (swr1 - nir) / (swr1 + nir + 1e-6),
198
- "Red": red,
199
- "Green": green,
200
- "Blue": blue,
201
- "NIR": nir,
202
- "SWIR": swr1,
203
- "Date": date
204
- })
205
-
206
- # Create DataFrame and preprocess
207
- df = pd.DataFrame(pixels)
208
- try:
209
- df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
210
- except:
211
- return "Invalid date format. Please use DD/MM/YYYY.", None
212
- df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
213
- df["Month"] = df["Date"].dt.month
214
- df.drop(columns=["Date"], inplace=True)
215
-
216
- # Dummy encoding and feature alignment
217
- df = pd.get_dummies(df, columns=['Province', 'Season'], dummy_na=True)
218
- missing_cols = set(feature_columns) - set(df.columns)
219
- for col in missing_cols:
220
- df[col] = 0
221
- df = df[feature_columns]
222
- df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
223
-
224
- # Model prediction
225
- try:
226
- X_scaled = scaler.transform(df)
227
- except Exception as e:
228
- return f"Error scaling features: {str(e)}", None
229
- X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
230
- with torch.no_grad():
231
- outputs = model(X_tensor)
232
- valid_user_classes = get_valid_user_classes(province, season)
233
- user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
234
- if user_class_indices:
235
- mask = torch.ones_like(outputs) * -1e10
236
- for idx in user_class_indices:
237
- mask[:, idx] = 0
238
- outputs = outputs + mask
239
- probs = torch.softmax(outputs, dim=1)
240
- max_probs, preds = torch.max(probs, dim=1)
241
- uncertain_mask = max_probs < uncertainty_threshold
242
- preds[uncertain_mask] = uncertain_class_idx
243
- preds = preds.cpu().numpy().reshape(H, W)
244
-
245
- # Create visualization
246
- unique_classes = np.unique(preds)
247
- color_map = assign_crop_colors([idx_to_label[cls] for cls in unique_classes])
248
- mask_img = np.zeros((H, W, 3))
249
- for cls, color in color_map.items():
250
- mask_img[preds == label_to_idx.get(cls, uncertain_class_idx)] = matplotlib.colors.to_rgb(color)
251
-
252
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
253
- ax1.imshow(rgb_norm)
254
- ax1.set_title("Original RGB Patch")
255
- ax1.axis("off")
256
- ax2.imshow(mask_img)
257
- ax2.set_title("Predicted Crop Classification")
258
- ax2.axis("off")
259
- legend_elements = [Patch(facecolor=color_map[idx_to_label[cls]], edgecolor='black', label=idx_to_label[cls]) for cls in unique_classes]
260
- fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.15, 0.5), title="Predicted Crops")
261
- plt.tight_layout()
262
-
263
- buf = BytesIO()
264
- plt.savefig(buf, format="png", bbox_inches="tight")
265
- plt.close()
266
- buf.seek(0)
267
- image = Image.open(buf)
268
-
269
- # Generate prediction statistics
270
- stats = "Prediction Statistics:\n"
271
- for cls in unique_classes:
272
- class_name = idx_to_label[cls]
273
- pixel_count = np.sum(preds == cls)
274
- percentage = (pixel_count / (H * W)) * 100
275
- stats += f"{class_name}: {pixel_count} pixels ({percentage:.2f}%)\n"
276
-
277
- return stats, image
278
-
279
- # --- Map Interface ---
280
- def generate_grid_points(polygon, spacing_deg):
281
- min_lon, min_lat, max_lon, max_lat = polygon.bounds
282
- grid_points = []
283
- point_id = 1
284
- lat_step = spacing_deg / 2
285
- lon_step = spacing_deg / 2
286
- lat = min_lat
287
- while lat <= max_lat:
288
- lon = min_lon
289
- while lon <= max_lon:
290
- pt = Point(lon, lat)
291
- if polygon.contains(pt):
292
- is_spaced = True
293
- for existing_pt in grid_points:
294
- dist = ((existing_pt["latitude"] - lat) ** 2 + (existing_pt["longitude"] - lon) ** 2) ** 0.5
295
- if dist < spacing_deg:
296
- is_spaced = False
297
- break
298
- if is_spaced:
299
- grid_points.append({
300
- "point_id": point_id,
301
- "latitude": round(lat, 6),
302
- "longitude": round(lon, 6)
303
- })
304
- point_id += 1
305
- lon += lon_step
306
- lat += lat_step
307
- return grid_points
308
-
309
- def get_indices(lat, lon, date_str):
310
- try:
311
- point = ee.Geometry.Point([lon, lat])
312
- date = datetime.strptime(date_str, "%d/%m/%Y")
313
- start = ee.Date(date.strftime('%Y-%m-%d'))
314
- end = ee.Date((date + timedelta(days=30)).strftime('%Y-%m-%d'))
315
-
316
- collection = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
317
- .filterBounds(point)
318
- .filterDate(start, end)
319
- .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)))
320
-
321
- image = collection.median().clip(point)
322
-
323
- band_names = image.bandNames().getInfo()
324
- if not band_names:
325
- return None
326
-
327
- B2 = image.select('B2') # Blue
328
- B3 = image.select('B3') # Green
329
- B4 = image.select('B4') # Red
330
- B8 = image.select('B8') # NIR
331
- B11 = image.select('B11') # SWIR
332
-
333
- ndvi = image.normalizedDifference(['B8', 'B4']).rename('NDVI')
334
- ndwi = image.normalizedDifference(['B3', 'B8']).rename('NDWI')
335
- evi = image.expression(
336
- '2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))',
337
- {'NIR': B8, 'RED': B4, 'BLUE': B2}).rename('EVI')
338
- gndvi = image.normalizedDifference(['B8', 'B3']).rename('GNDVI')
339
- savi = image.expression(
340
- '((NIR - RED) / (NIR + RED + 0.5)) * 1.5',
341
- {'NIR': B8, 'RED': B4}).rename('SAVI')
342
-
343
- all_bands = image.addBands([ndvi, ndwi, evi, gndvi, savi])
344
-
345
- values = all_bands.reduceRegion(
346
- reducer=ee.Reducer.first(),
347
- geometry=point,
348
- scale=10,
349
- maxPixels=1e8
350
- ).getInfo()
351
-
352
- return {
353
- 'NDVI': values.get('NDVI', 0.0),
354
- 'NDWI': values.get('NDWI', 0.0),
355
- 'EVI': values.get('EVI', 0.0),
356
- 'GNDVI': values.get('GNDVI', 0.0),
357
- 'SAVI': values.get('SAVI', 0.0),
358
- 'Red': values.get('B4', 0.0),
359
- 'Green': values.get('B3', 0.0),
360
- 'Blue': values.get('B2', 0.0),
361
- 'NIR': values.get('B8', 0.0),
362
- 'SWIR': values.get('B11', 0.0)
363
- }
364
- except Exception as e:
365
- print(f"Error fetching indices for lat={lat}, lon={lon}: {str(e)}")
366
- return None
367
-
368
- def predict_crop_description(point, static_features, scaler, feature_columns, province, season):
369
- df = pd.DataFrame([{
370
- **static_features,
371
- "Latitude": point["latitude"],
372
- "Longitude": point["longitude"],
373
- "Date": static_features["Date"]
374
- }])
375
- df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
376
- df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
377
- df["Month"] = df["Date"].dt.month
378
- df.drop(columns=["Date"], inplace=True)
379
- df = pd.get_dummies(df)
380
- for col in feature_columns:
381
- if col not in df.columns:
382
- df[col] = 0
383
- df = df[feature_columns]
384
- df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
385
- scaled = scaler.transform(df)
386
- X_tensor = torch.tensor(scaled, dtype=torch.float32).to(device)
387
- with torch.no_grad():
388
- outputs = model(X_tensor)
389
- valid_user_classes = get_valid_user_classes(province, season)
390
- user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
391
- if user_class_indices:
392
- mask = torch.ones_like(outputs) * -1e10
393
- for idx in user_class_indices:
394
- mask[:, idx] = 0
395
- outputs = outputs + mask
396
- probs = torch.softmax(outputs, dim=1)
397
- max_probs, preds = torch.max(probs, dim=1)
398
- uncertain_mask = max_probs < uncertainty_threshold
399
- preds[uncertain_mask] = uncertain_class_idx
400
- return idx_to_label[preds.cpu().numpy()[0]]
401
-
402
- def create_interactive_map():
403
- m = folium.Map(location=[30.809, 73.45], zoom_start=12)
404
- Draw(
405
- export=True,
406
- filename='polygon.geojson',
407
- draw_options={
408
- "polyline": False,
409
- "rectangle": True,
410
- "circle": True,
411
- "circlemarker": False,
412
- "marker": False,
413
- "polygon": True
414
- }
415
- ).add_to(m)
416
- return m._repr_html_()
417
-
418
- def select_polygon(geojson_file):
419
- global current_polygon_data
420
- if not geojson_file:
421
- return "❌ No GeoJSON file uploaded. Please draw a polygon, export it, and upload the file."
422
-
423
- try:
424
- with open(geojson_file.name, 'r') as f:
425
- geojson_data = json.load(f)
426
-
427
- if geojson_data.get('type') == 'FeatureCollection':
428
- features = geojson_data.get('features', [])
429
- for feature in features:
430
- if feature.get('geometry', {}).get('type') == 'Polygon':
431
- current_polygon_data = feature
432
- return "βœ… Polygon selected successfully!"
433
- return "❌ No valid polygon found in the GeoJSON file."
434
- except Exception as e:
435
- return f"Error reading GeoJSON file: {str(e)}"
436
-
437
- def process_polygon_prediction(spacing_m, province, season, date, geojson_file):
438
- global current_polygon_data
439
-
440
- try:
441
- datetime.strptime(date, "%d/%m/%Y")
442
- except ValueError:
443
- return "Invalid date format. Please use DD/MM/YYYY.", None, None
444
-
445
- if not current_polygon_data:
446
- return "❌ No polygon selected. Please draw a polygon, export it as GeoJSON, and upload it.", None, None
447
-
448
- try:
449
- polygon = shape(current_polygon_data['geometry'])
450
- except Exception as e:
451
- return f"Error parsing polygon: {str(e)}", None, None
452
-
453
- spacing_deg = spacing_m / 111320.0
454
- points = generate_grid_points(polygon, spacing_deg)
455
- print(f"Number of points selected: {len(points)}")
456
-
457
- if not points:
458
- return "No points generated within the polygon. Try increasing the spacing.", None, None
459
-
460
- predicted_points = []
461
- static_features = {
462
- "Province": province,
463
- "Season": season,
464
- "Date": date
465
- }
466
-
467
- for i, point in enumerate(points, 1):
468
- indices = get_indices(point["latitude"], point["longitude"], date)
469
- print(f"GEE started for point {i} at lat={point['latitude']}, lon={point['longitude']}")
470
- if indices:
471
- print(f"GEE values fetched for point {i}")
472
- static_features.update({
473
- "NDVI": indices["NDVI"],
474
- "NDWI": indices["NDWI"],
475
- "EVI": indices["EVI"],
476
- "GNDVI": indices["GNDVI"],
477
- "SAVI": indices["SAVI"],
478
- "Red": indices["Red"],
479
- "Green": indices["Green"],
480
- "Blue": indices["Blue"],
481
- "NIR": indices["NIR"],
482
- "SWIR": indices["SWIR"]
483
- })
484
- crop = predict_crop_description(point, static_features, scaler, feature_columns, province, season)
485
- point.update({
486
- "crop": crop,
487
- "NDVI": indices["NDVI"],
488
- "NDWI": indices["NDWI"],
489
- "EVI": indices["EVI"],
490
- "GNDVI": indices["GNDVI"],
491
- "SAVI": indices["SAVI"]
492
- })
493
- predicted_points.append(point)
494
-
495
- if not predicted_points:
496
- return "No valid data found for any grid points.", None, None
497
-
498
- pred_df = pd.DataFrame(predicted_points)
499
- unique_crops = pred_df['crop'].unique()
500
- crop_colors = assign_crop_colors(unique_crops)
501
-
502
- center_lat = sum(pt["latitude"] for pt in predicted_points) / len(predicted_points)
503
- center_lon = sum(pt["longitude"] for pt in predicted_points) / len(predicted_points)
504
- pred_map = folium.Map(location=[center_lat, center_lon], zoom_start=12)
505
-
506
- folium.GeoJson(
507
- current_polygon_data,
508
- style_function=lambda x: {'color': 'red', 'weight': 3, 'fill': False}
509
- ).add_to(pred_map)
510
-
511
- for pt in predicted_points:
512
- crop_type = pt.get("crop", "Other")
513
- color = crop_colors.get(crop_type, "#808080")
514
- folium.Circle(
515
- location=[pt["latitude"], pt["longitude"]],
516
- radius=spacing_m/2,
517
- color='black',
518
- weight=1,
519
- fill=True,
520
- fillColor=color,
521
- fillOpacity=0.7,
522
- popup=f"Crop: {crop_type}<br>Lat: {pt['latitude']:.4f}<br>Lon: {pt['longitude']:.4f}<br>NDVI: {pt['NDVI']:.3f}<br>NDWI: {pt['NDWI']:.3f}<br>EVI: {pt['EVI']:.3f}<br>GNDVI: {pt['GNDVI']:.3f}<br>SAVI: {pt['SAVI']:.3f}",
523
- tooltip=crop_type
524
- ).add_to(pred_map)
525
-
526
- legend_html = '''
527
- <div style="position: fixed; bottom: 50px; left: 50px; width: 180px;
528
- background-color: white; border:2px solid grey; z-index:9999;
529
- font-size:14px; padding: 10px; border-radius: 5px;">
530
- <p style="margin: 0 0 10px 0; font-weight:bold;">🌾 Crop Types</p>
531
- '''
532
- for crop in unique_crops:
533
- color = crop_colors[crop]
534
- legend_html += f'<p style="margin: 5px 0;"><span style="color:{color}; font-size:16px;">●</span> {crop}</p>'
535
- legend_html += '</div>'
536
- pred_map.get_root().html.add_child(folium.Element(legend_html))
537
-
538
- crop_stats = pred_df['crop'].value_counts()
539
- stats = f"βœ… Polygon processed successfully!\n\nCrop Distribution (Province: {province}, Season: {season}):\n"
540
- for crop, count in crop_stats.items():
541
- percentage = (count / len(predicted_points)) * 100
542
- stats += f"{crop}: {count} points ({percentage:.1f}%)\n"
543
- for index in ['NDVI', 'NDWI', 'EVI', 'GNDVI', 'SAVI']:
544
- avg = pred_df[index].mean()
545
- stats += f"Average {index}: {avg:.3f}\n"
546
-
547
- csv_file_path = f"crop_predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
548
- try:
549
- pred_df.to_csv(csv_file_path, index=False)
550
- except Exception as e:
551
- print(f"Error creating CSV file: {str(e)}")
552
- csv_file_path = None
553
-
554
- return stats, pred_map._repr_html_(), csv_file_path
555
-
556
- # --- Instance Interface ---
557
- def predict_instance(province, season, latitude, longitude, date, ndvi, ndwi, ndbi, red, green, blue, nir, swir):
558
- static_features = {
559
- "Province": province,
560
- "Season": season,
561
- "NDVI": ndvi,
562
- "NDWI": ndwi,
563
- "NDBI": ndbi,
564
- "Red": red,
565
- "Green": green,
566
- "Blue": blue,
567
- "NIR": nir,
568
- "SWIR": swir,
569
- "Date": date
570
- }
571
- crop = predict_crop_description({"latitude": latitude, "longitude": longitude}, static_features, scaler, feature_columns, province, season)
572
- return f"{crop}"
573
-
574
- from pathlib import Path
575
- import gradio as gr
576
-
577
- # Sample file paths
578
- sample_dir = Path("samples") # Ensure this directory exists with .tif files
579
- sample_files = {
580
- "Sample 1": sample_dir / "sample1.tif",
581
- "Sample 2": sample_dir / "sample2.tif"
582
- }
583
-
584
- # Function to simulate upload when sample is clicked
585
- def load_sample_and_predict(sample_name, province, season, date):
586
- file_path = sample_files[sample_name]
587
- return process_upload(file_path, province, season, date)
588
-
589
- # --- Gradio Interface ---
590
- with gr.Blocks(title="Crop Predictor", theme=gr.themes.Soft()) as demo:
591
- gr.Markdown("# 🌾 Crop Predictor")
592
-
593
- with gr.Tabs():
594
- with gr.TabItem("πŸ“€ Upload"):
595
- gr.Markdown("Upload a .tiff or .tif file with bands [r, g, b, rededge, nir, swr1, swr2]")
596
-
597
- file_input = gr.File(label="Upload .tiff/.tif file", file_types=[".tiff", ".tif"])
598
-
599
- with gr.Row():
600
- province = gr.Textbox(label="Province", value="Punjab")
601
- season = gr.Textbox(label="Season", value="Rabi")
602
-
603
- with gr.Row():
604
- date = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
605
-
606
- upload_btn = gr.Button("πŸ” Predict", variant="primary")
607
- output_stats = gr.Textbox(label="Prediction Statistics", lines=10)
608
- output_image = gr.Image(label="Prediction Result")
609
-
610
- upload_btn.click(
611
- fn=process_upload,
612
- inputs=[file_input, province, season, date],
613
- outputs=[output_stats, output_image]
614
- )
615
-
616
- # -- Add Sample File Buttons Here --
617
- gr.Markdown("### Or try with a sample file:")
618
- with gr.Row():
619
- for name in sample_files:
620
- gr.Button(name).click(
621
- fn=load_sample_and_predict,
622
- inputs=[gr.State(name), province, season, date],
623
- outputs=[output_stats, output_image]
624
- )
625
-
626
- with gr.TabItem("πŸ—ΊοΈ Map"):
627
- gr.Markdown("""
628
- ## Interactive Polygon Crop Prediction
629
-
630
- **Instructions:**
631
- 1. Draw a polygon on the map below using the polygon tool.
632
- 2. Click the "Export" button on the map to save the polygon as a GeoJSON file (polygon.geojson).
633
- 3. Upload the exported GeoJSON file using the file input below.
634
- 4. Adjust settings and click "πŸ” Predict" to process.
635
- """)
636
-
637
- map_html = gr.HTML(create_interactive_map, label="Draw Your Polygon Here")
638
-
639
- with gr.Row():
640
- geojson_input = gr.File(label="Upload Exported GeoJSON File")
641
- select_btn = gr.Button("🎯 Select My Polygon", variant="secondary")
642
- spacing = gr.Slider(
643
- label="Grid Spacing (meters)",
644
- minimum=10, maximum=1000, value=30, step=100
645
- )
646
-
647
- with gr.Row():
648
- province_map = gr.Textbox(label="Province", value="Punjab")
649
- season_map = gr.Textbox(label="Season", value="Multan")
650
- date_map = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
651
-
652
- polygon_status = gr.Textbox(
653
- label="Selection Status",
654
- value="⏳ Please draw a polygon, export it, and upload the GeoJSON file.",
655
- interactive=False
656
- )
657
-
658
- predict_btn = gr.Button("πŸ” Predict Crops", variant="primary", size="lg")
659
-
660
- output_map_stats = gr.Textbox(label="Prediction Results", lines=10)
661
- output_map = gr.HTML(label="Crop Prediction Map")
662
- output_csv = gr.File(label="πŸ“₯ Download Results CSV")
663
-
664
- select_btn.click(
665
- fn=select_polygon,
666
- inputs=[geojson_input],
667
- outputs=polygon_status
668
- )
669
-
670
- predict_btn.click(
671
- fn=process_polygon_prediction,
672
- inputs=[spacing, province_map, season_map, date_map, geojson_input],
673
- outputs=[output_map_stats, output_map, output_csv]
674
- )
675
-
676
- with gr.TabItem("πŸ“Š Instance"):
677
- gr.Markdown("## Single Point Prediction")
678
- gr.Markdown("Enter features manually for a single point prediction")
679
-
680
- with gr.Row():
681
- province_inst = gr.Textbox(label="Province", value="Punjab")
682
- season_inst = gr.Textbox(label="Season", value="Rabi")
683
-
684
- with gr.Row():
685
- latitude_inst = gr.Number(label="Latitude", value=30.809)
686
- longitude_inst = gr.Number(label="Longitude", value=73.450)
687
- date_inst = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
688
-
689
- gr.Markdown("### Spectral Indices")
690
- with gr.Row():
691
- ndvi_inst = gr.Number(label="NDVI", value=0.65)
692
- ndwi_inst = gr.Number(label="NDWI", value=-2.0)
693
- ndbi_inst = gr.Number(label="NDBI", value=0.10)
694
-
695
- gr.Markdown("### Band Values")
696
- with gr.Row():
697
- red_inst = gr.Number(label="Red", value=678)
698
- green_inst = gr.Number(label="Green", value=732)
699
- blue_inst = gr.Number(label="Blue", value=620)
700
-
701
- with gr.Row():
702
- nir_inst = gr.Number(label="NIR", value=3000)
703
- swir_inst = gr.Number(label="SWIR", value=1800)
704
-
705
- instance_btn = gr.Button("πŸ” Predict", variant="primary")
706
- output_instance = gr.Textbox(label="Prediction Result", lines=3)
707
-
708
- instance_btn.click(
709
- fn=predict_instance,
710
- inputs=[province_inst, season_inst, latitude_inst, longitude_inst,
711
- date_inst, ndvi_inst, ndwi_inst, ndbi_inst, red_inst,
712
- green_inst, blue_inst, nir_inst, swir_inst],
713
- outputs=output_instance
714
- )
715
-
716
- if __name__ == "__main__":
 
 
717
  demo.launch(share=True)
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn as nn
6
+ import joblib
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.patches import Patch
9
+ import matplotlib
10
+ from shapely.geometry import shape, Point
11
+ import folium
12
+ from folium.plugins import Draw
13
+ from io import BytesIO
14
+ import base64
15
+ import json
16
+ import os
17
+ from PIL import Image
18
+ import ee
19
+ from datetime import datetime, timedelta
20
+ import rasterio
21
+ from rasterio.transform import xy
22
+
23
+ # Absolute path to your credentials file (update this if needed)
24
+ CREDENTIALS_PATH = "credentials" # Windows
25
+
26
+ try:
27
+ # Use OAuth credentials from file
28
+ credentials = ee.OAuthCredentials(None, CREDENTIALS_PATH)
29
+ ee.Initialize(credentials, project='your-project-id') # Replace with your actual project ID
30
+ except Exception as e:
31
+ print(f"Error initializing GEE: {str(e)}")
32
+
33
+ # Define crop season dictionary
34
+ crop_season_dict = {
35
+ "Punjab": {
36
+ "Rabi": [
37
+ "wheat", "barley", "gram (chickpea)", "lentil", "mustard", "rapeseed mustard",
38
+ "linseed", "peas", "garlic", "onion", "coriander", "fennel", "potato",
39
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
40
+ ],
41
+ "Kharif": [
42
+ "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "sorghum", "sunflower",
43
+ "groundnuts", "okra", "tomato", "chillies", "banana", "mango",
44
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
45
+ ]
46
+ },
47
+ "Sindh": {
48
+ "Rabi": [
49
+ "wheat", "barley", "peas", "gram (chickpea)", "mustard", "onion", "garlic", "spinach",
50
+ "coriander", "potato", "fennel", "turnip",
51
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
52
+ ],
53
+ "Kharif": [
54
+ "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "okra", "tomato",
55
+ "chillies", "banana", "mango", "sunflower", "guava",
56
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
57
+ ]
58
+ },
59
+ "Balochistan": {
60
+ "Rabi": [
61
+ "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "potato",
62
+ "onion", "coriander", "fallow (agriculture)", "water", "barren", "shrubs", "forest"
63
+ ],
64
+ "Kharif": [
65
+ "maize", "rice", "millet", "sorghum", "peach", "apple", "grapes", "tomato",
66
+ "chillies", "pomegranate", "groundnuts", "sunflower",
67
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
68
+ ]
69
+ },
70
+ "Khyber Pakhtunkhwa": {
71
+ "Rabi": [
72
+ "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "onion",
73
+ "garlic", "turnip", "potato", "coriander",
74
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
75
+ ],
76
+ "Kharif": [
77
+ "maize", "rice", "sugarcane", "tomato", "chillies", "peach", "plum", "apricot",
78
+ "apple", "mango", "sunflower", "okra", "sesame",
79
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
80
+ ]
81
+ }
82
+ }
83
+
84
+ # Define model
85
+ class CropClassifier(nn.Module):
86
+ def __init__(self, input_size, num_classes):
87
+ super(CropClassifier, self).__init__()
88
+ self.network = nn.Sequential(
89
+ nn.Linear(input_size, 512),
90
+ nn.BatchNorm1d(512),
91
+ nn.LeakyReLU(),
92
+ nn.Dropout(0.4),
93
+ nn.Linear(512, 256),
94
+ nn.BatchNorm1d(256),
95
+ nn.LeakyReLU(),
96
+ nn.Dropout(0.3),
97
+ nn.Linear(256, 128),
98
+ nn.BatchNorm1d(128),
99
+ nn.LeakyReLU(),
100
+ nn.Dropout(0.2),
101
+ nn.Linear(128, 64),
102
+ nn.BatchNorm1d(64),
103
+ nn.LeakyReLU(),
104
+ nn.Dropout(0.1),
105
+ nn.Linear(64, num_classes)
106
+ )
107
+ def forward(self, x):
108
+ return self.network(x)
109
+
110
+ # Load saved objects
111
+ scaler = joblib.load("scaler.pkl")
112
+ label_to_idx = joblib.load("label_encoder.pkl")
113
+ feature_columns = joblib.load("feature_columns.pkl")
114
+ idx_to_label = {v: k for k, v in label_to_idx.items()}
115
+
116
+ # Load model
117
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
118
+ model = CropClassifier(len(feature_columns), len(label_to_idx)).to(device)
119
+ model.load_state_dict(torch.load("final_crop_model.pth", map_location=device))
120
+ model.eval()
121
+
122
+ # Uncertainty threshold
123
+ uncertainty_threshold = 0.2
124
+ uncertain_class_idx = len(label_to_idx)
125
+ idx_to_label[uncertain_class_idx] = "Uncertain"
126
+
127
+ # Global variable to store current polygon
128
+ current_polygon_data = None
129
+
130
+ def get_color_palette(n):
131
+ if n <= 20:
132
+ palette = list(matplotlib.colors.TABLEAU_COLORS.values()) + list(matplotlib.colors.CSS4_COLORS.values())
133
+ return palette[:n]
134
+ else:
135
+ return [matplotlib.colors.rgb2hex(matplotlib.cm.hsv(i/n)) for i in range(n)]
136
+
137
+ def assign_crop_colors(unique_crops):
138
+ palette = get_color_palette(len(unique_crops))
139
+ return {crop: palette[i] for i, crop in enumerate(unique_crops)}
140
+
141
+ def get_valid_user_classes(province, season):
142
+ """Fetch valid classes based on province and season from crop_season_dict."""
143
+ try:
144
+ user_classes = crop_season_dict.get(province, {}).get(season, [])
145
+ return [cls for cls in user_classes if cls in label_to_idx]
146
+ except:
147
+ return []
148
+
149
+ # --- Upload Processing Function ---
150
+ def process_upload(file, province, season, date):
151
+ if file is None:
152
+ return "No file uploaded. Please upload a .tiff or .tif file.", None
153
+
154
+ if not file.name.endswith(('.tiff', '.tif')):
155
+ return "Unsupported file format. Please upload a .tiff or .tif file.", None
156
+
157
+ # Load GeoTIFF file
158
+ try:
159
+ with rasterio.open(file) as src:
160
+ patch = src.read() # Shape: (bands, height, width)
161
+ transform = src.transform
162
+ rows, cols = patch.shape[1], patch.shape[2]
163
+ row_indices, col_indices = np.meshgrid(np.arange(rows), np.arange(cols), indexing='ij')
164
+ lon, lat = xy(transform, row_indices, col_indices)
165
+ # Convert lon, lat to 2D arrays (shape: [rows, cols])
166
+ lon_mask = np.array(lon).reshape(rows, cols)
167
+ lat_mask = np.array(lat).reshape(rows, cols)
168
+ except Exception as e:
169
+ return f"Error reading GeoTIFF file: {str(e)}", None
170
+
171
+ # Validate the number of bands
172
+ if len(patch.shape) != 3 or patch.shape[0] < 7:
173
+ return "Invalid GeoTIFF file format. Expected at least 7 bands [r, g, b, rededge, nir, swr1, swr2].", None
174
+
175
+ # # Resize patch to 500x500 if necessary
176
+ # patch = patch[:, :500, :500]
177
+ patch = np.transpose(patch, (1, 2, 0)) # Shape: (H, W, 7)
178
+ H, W, _ = patch.shape
179
+
180
+ # Extract RGB for visualization
181
+ r, g, b = patch[..., 0], patch[..., 1], patch[..., 2]
182
+ rgb = np.stack([r, g, b], axis=-1).astype(np.float32)
183
+ rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)
184
+
185
+
186
+ # Process pixels for prediction
187
+ pixels = []
188
+ for i in range(H):
189
+ for j in range(W):
190
+ pix = patch[i, j].astype(np.float32)
191
+ red, green, blue, nir, swr1 = pix[0], pix[1], pix[2], pix[4], pix[5]
192
+ pixels.append({
193
+ "Province": province,
194
+ "Season": season,
195
+ "Latitude": lat_mask[i, j],
196
+ "Longitude": lon_mask[i, j],
197
+ "NDVI": (nir - red) / (nir + red + 1e-6),
198
+ "NDWI": (green - nir) / (green + nir + 1e-6),
199
+ "NDBI": (swr1 - nir) / (swr1 + nir + 1e-6),
200
+ "Red": red,
201
+ "Green": green,
202
+ "Blue": blue,
203
+ "NIR": nir,
204
+ "SWIR": swr1,
205
+ "Date": date
206
+ })
207
+
208
+ # Create DataFrame and preprocess
209
+ df = pd.DataFrame(pixels)
210
+ try:
211
+ df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
212
+ except:
213
+ return "Invalid date format. Please use DD/MM/YYYY.", None
214
+ df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
215
+ df["Month"] = df["Date"].dt.month
216
+ df.drop(columns=["Date"], inplace=True)
217
+
218
+ # Dummy encoding and feature alignment
219
+ df = pd.get_dummies(df, columns=['Province', 'Season'], dummy_na=True)
220
+ missing_cols = set(feature_columns) - set(df.columns)
221
+ for col in missing_cols:
222
+ df[col] = 0
223
+ df = df[feature_columns]
224
+ df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
225
+
226
+ # Model prediction
227
+ try:
228
+ X_scaled = scaler.transform(df)
229
+ except Exception as e:
230
+ return f"Error scaling features: {str(e)}", None
231
+ X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
232
+ with torch.no_grad():
233
+ outputs = model(X_tensor)
234
+ valid_user_classes = get_valid_user_classes(province, season)
235
+ user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
236
+ if user_class_indices:
237
+ mask = torch.ones_like(outputs) * -1e10
238
+ for idx in user_class_indices:
239
+ mask[:, idx] = 0
240
+ outputs = outputs + mask
241
+ probs = torch.softmax(outputs, dim=1)
242
+ max_probs, preds = torch.max(probs, dim=1)
243
+ uncertain_mask = max_probs < uncertainty_threshold
244
+ preds[uncertain_mask] = uncertain_class_idx
245
+ preds = preds.cpu().numpy().reshape(H, W)
246
+
247
+ # Create visualization
248
+ unique_classes = np.unique(preds)
249
+ color_map = assign_crop_colors([idx_to_label[cls] for cls in unique_classes])
250
+ mask_img = np.zeros((H, W, 3))
251
+ for cls, color in color_map.items():
252
+ mask_img[preds == label_to_idx.get(cls, uncertain_class_idx)] = matplotlib.colors.to_rgb(color)
253
+
254
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
255
+ ax1.imshow(rgb_norm)
256
+ ax1.set_title("Original RGB Patch")
257
+ ax1.axis("off")
258
+ ax2.imshow(mask_img)
259
+ ax2.set_title("Predicted Crop Classification")
260
+ ax2.axis("off")
261
+ legend_elements = [Patch(facecolor=color_map[idx_to_label[cls]], edgecolor='black', label=idx_to_label[cls]) for cls in unique_classes]
262
+ fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.15, 0.5), title="Predicted Crops")
263
+ plt.tight_layout()
264
+
265
+ buf = BytesIO()
266
+ plt.savefig(buf, format="png", bbox_inches="tight")
267
+ plt.close()
268
+ buf.seek(0)
269
+ image = Image.open(buf)
270
+
271
+ # Generate prediction statistics
272
+ stats = "Prediction Statistics:\n"
273
+ for cls in unique_classes:
274
+ class_name = idx_to_label[cls]
275
+ pixel_count = np.sum(preds == cls)
276
+ percentage = (pixel_count / (H * W)) * 100
277
+ stats += f"{class_name}: {pixel_count} pixels ({percentage:.2f}%)\n"
278
+
279
+ return stats, image
280
+
281
+ # --- Map Interface ---
282
+ def generate_grid_points(polygon, spacing_deg):
283
+ min_lon, min_lat, max_lon, max_lat = polygon.bounds
284
+ grid_points = []
285
+ point_id = 1
286
+ lat_step = spacing_deg / 2
287
+ lon_step = spacing_deg / 2
288
+ lat = min_lat
289
+ while lat <= max_lat:
290
+ lon = min_lon
291
+ while lon <= max_lon:
292
+ pt = Point(lon, lat)
293
+ if polygon.contains(pt):
294
+ is_spaced = True
295
+ for existing_pt in grid_points:
296
+ dist = ((existing_pt["latitude"] - lat) ** 2 + (existing_pt["longitude"] - lon) ** 2) ** 0.5
297
+ if dist < spacing_deg:
298
+ is_spaced = False
299
+ break
300
+ if is_spaced:
301
+ grid_points.append({
302
+ "point_id": point_id,
303
+ "latitude": round(lat, 6),
304
+ "longitude": round(lon, 6)
305
+ })
306
+ point_id += 1
307
+ lon += lon_step
308
+ lat += lat_step
309
+ return grid_points
310
+
311
+ def get_indices(lat, lon, date_str):
312
+ try:
313
+ point = ee.Geometry.Point([lon, lat])
314
+ date = datetime.strptime(date_str, "%d/%m/%Y")
315
+ start = ee.Date(date.strftime('%Y-%m-%d'))
316
+ end = ee.Date((date + timedelta(days=30)).strftime('%Y-%m-%d'))
317
+
318
+ collection = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
319
+ .filterBounds(point)
320
+ .filterDate(start, end)
321
+ .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)))
322
+
323
+ image = collection.median().clip(point)
324
+
325
+ band_names = image.bandNames().getInfo()
326
+ if not band_names:
327
+ return None
328
+
329
+ B2 = image.select('B2') # Blue
330
+ B3 = image.select('B3') # Green
331
+ B4 = image.select('B4') # Red
332
+ B8 = image.select('B8') # NIR
333
+ B11 = image.select('B11') # SWIR
334
+
335
+ ndvi = image.normalizedDifference(['B8', 'B4']).rename('NDVI')
336
+ ndwi = image.normalizedDifference(['B3', 'B8']).rename('NDWI')
337
+ evi = image.expression(
338
+ '2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))',
339
+ {'NIR': B8, 'RED': B4, 'BLUE': B2}).rename('EVI')
340
+ gndvi = image.normalizedDifference(['B8', 'B3']).rename('GNDVI')
341
+ savi = image.expression(
342
+ '((NIR - RED) / (NIR + RED + 0.5)) * 1.5',
343
+ {'NIR': B8, 'RED': B4}).rename('SAVI')
344
+
345
+ all_bands = image.addBands([ndvi, ndwi, evi, gndvi, savi])
346
+
347
+ values = all_bands.reduceRegion(
348
+ reducer=ee.Reducer.first(),
349
+ geometry=point,
350
+ scale=10,
351
+ maxPixels=1e8
352
+ ).getInfo()
353
+
354
+ return {
355
+ 'NDVI': values.get('NDVI', 0.0),
356
+ 'NDWI': values.get('NDWI', 0.0),
357
+ 'EVI': values.get('EVI', 0.0),
358
+ 'GNDVI': values.get('GNDVI', 0.0),
359
+ 'SAVI': values.get('SAVI', 0.0),
360
+ 'Red': values.get('B4', 0.0),
361
+ 'Green': values.get('B3', 0.0),
362
+ 'Blue': values.get('B2', 0.0),
363
+ 'NIR': values.get('B8', 0.0),
364
+ 'SWIR': values.get('B11', 0.0)
365
+ }
366
+ except Exception as e:
367
+ print(f"Error fetching indices for lat={lat}, lon={lon}: {str(e)}")
368
+ return None
369
+
370
+ def predict_crop_description(point, static_features, scaler, feature_columns, province, season):
371
+ df = pd.DataFrame([{
372
+ **static_features,
373
+ "Latitude": point["latitude"],
374
+ "Longitude": point["longitude"],
375
+ "Date": static_features["Date"]
376
+ }])
377
+ df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
378
+ df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
379
+ df["Month"] = df["Date"].dt.month
380
+ df.drop(columns=["Date"], inplace=True)
381
+ df = pd.get_dummies(df)
382
+ for col in feature_columns:
383
+ if col not in df.columns:
384
+ df[col] = 0
385
+ df = df[feature_columns]
386
+ df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
387
+ scaled = scaler.transform(df)
388
+ X_tensor = torch.tensor(scaled, dtype=torch.float32).to(device)
389
+ with torch.no_grad():
390
+ outputs = model(X_tensor)
391
+ valid_user_classes = get_valid_user_classes(province, season)
392
+ user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
393
+ if user_class_indices:
394
+ mask = torch.ones_like(outputs) * -1e10
395
+ for idx in user_class_indices:
396
+ mask[:, idx] = 0
397
+ outputs = outputs + mask
398
+ probs = torch.softmax(outputs, dim=1)
399
+ max_probs, preds = torch.max(probs, dim=1)
400
+ uncertain_mask = max_probs < uncertainty_threshold
401
+ preds[uncertain_mask] = uncertain_class_idx
402
+ return idx_to_label[preds.cpu().numpy()[0]]
403
+
404
+ def create_interactive_map():
405
+ m = folium.Map(location=[30.809, 73.45], zoom_start=12)
406
+ Draw(
407
+ export=True,
408
+ filename='polygon.geojson',
409
+ draw_options={
410
+ "polyline": False,
411
+ "rectangle": True,
412
+ "circle": True,
413
+ "circlemarker": False,
414
+ "marker": False,
415
+ "polygon": True
416
+ }
417
+ ).add_to(m)
418
+ return m._repr_html_()
419
+
420
+ def select_polygon(geojson_file):
421
+ global current_polygon_data
422
+ if not geojson_file:
423
+ return "❌ No GeoJSON file uploaded. Please draw a polygon, export it, and upload the file."
424
+
425
+ try:
426
+ with open(geojson_file.name, 'r') as f:
427
+ geojson_data = json.load(f)
428
+
429
+ if geojson_data.get('type') == 'FeatureCollection':
430
+ features = geojson_data.get('features', [])
431
+ for feature in features:
432
+ if feature.get('geometry', {}).get('type') == 'Polygon':
433
+ current_polygon_data = feature
434
+ return "βœ… Polygon selected successfully!"
435
+ return "❌ No valid polygon found in the GeoJSON file."
436
+ except Exception as e:
437
+ return f"Error reading GeoJSON file: {str(e)}"
438
+
439
+ def process_polygon_prediction(spacing_m, province, season, date, geojson_file):
440
+ global current_polygon_data
441
+
442
+ try:
443
+ datetime.strptime(date, "%d/%m/%Y")
444
+ except ValueError:
445
+ return "Invalid date format. Please use DD/MM/YYYY.", None, None
446
+
447
+ if not current_polygon_data:
448
+ return "❌ No polygon selected. Please draw a polygon, export it as GeoJSON, and upload it.", None, None
449
+
450
+ try:
451
+ polygon = shape(current_polygon_data['geometry'])
452
+ except Exception as e:
453
+ return f"Error parsing polygon: {str(e)}", None, None
454
+
455
+ spacing_deg = spacing_m / 111320.0
456
+ points = generate_grid_points(polygon, spacing_deg)
457
+ print(f"Number of points selected: {len(points)}")
458
+
459
+ if not points:
460
+ return "No points generated within the polygon. Try increasing the spacing.", None, None
461
+
462
+ predicted_points = []
463
+ static_features = {
464
+ "Province": province,
465
+ "Season": season,
466
+ "Date": date
467
+ }
468
+
469
+ for i, point in enumerate(points, 1):
470
+ indices = get_indices(point["latitude"], point["longitude"], date)
471
+ print(f"GEE started for point {i} at lat={point['latitude']}, lon={point['longitude']}")
472
+ if indices:
473
+ print(f"GEE values fetched for point {i}")
474
+ static_features.update({
475
+ "NDVI": indices["NDVI"],
476
+ "NDWI": indices["NDWI"],
477
+ "EVI": indices["EVI"],
478
+ "GNDVI": indices["GNDVI"],
479
+ "SAVI": indices["SAVI"],
480
+ "Red": indices["Red"],
481
+ "Green": indices["Green"],
482
+ "Blue": indices["Blue"],
483
+ "NIR": indices["NIR"],
484
+ "SWIR": indices["SWIR"]
485
+ })
486
+ crop = predict_crop_description(point, static_features, scaler, feature_columns, province, season)
487
+ point.update({
488
+ "crop": crop,
489
+ "NDVI": indices["NDVI"],
490
+ "NDWI": indices["NDWI"],
491
+ "EVI": indices["EVI"],
492
+ "GNDVI": indices["GNDVI"],
493
+ "SAVI": indices["SAVI"]
494
+ })
495
+ predicted_points.append(point)
496
+
497
+ if not predicted_points:
498
+ return "No valid data found for any grid points.", None, None
499
+
500
+ pred_df = pd.DataFrame(predicted_points)
501
+ unique_crops = pred_df['crop'].unique()
502
+ crop_colors = assign_crop_colors(unique_crops)
503
+
504
+ center_lat = sum(pt["latitude"] for pt in predicted_points) / len(predicted_points)
505
+ center_lon = sum(pt["longitude"] for pt in predicted_points) / len(predicted_points)
506
+ pred_map = folium.Map(location=[center_lat, center_lon], zoom_start=12)
507
+
508
+ folium.GeoJson(
509
+ current_polygon_data,
510
+ style_function=lambda x: {'color': 'red', 'weight': 3, 'fill': False}
511
+ ).add_to(pred_map)
512
+
513
+ for pt in predicted_points:
514
+ crop_type = pt.get("crop", "Other")
515
+ color = crop_colors.get(crop_type, "#808080")
516
+ folium.Circle(
517
+ location=[pt["latitude"], pt["longitude"]],
518
+ radius=spacing_m/2,
519
+ color='black',
520
+ weight=1,
521
+ fill=True,
522
+ fillColor=color,
523
+ fillOpacity=0.7,
524
+ popup=f"Crop: {crop_type}<br>Lat: {pt['latitude']:.4f}<br>Lon: {pt['longitude']:.4f}<br>NDVI: {pt['NDVI']:.3f}<br>NDWI: {pt['NDWI']:.3f}<br>EVI: {pt['EVI']:.3f}<br>GNDVI: {pt['GNDVI']:.3f}<br>SAVI: {pt['SAVI']:.3f}",
525
+ tooltip=crop_type
526
+ ).add_to(pred_map)
527
+
528
+ legend_html = '''
529
+ <div style="position: fixed; bottom: 50px; left: 50px; width: 180px;
530
+ background-color: white; border:2px solid grey; z-index:9999;
531
+ font-size:14px; padding: 10px; border-radius: 5px;">
532
+ <p style="margin: 0 0 10px 0; font-weight:bold;">🌾 Crop Types</p>
533
+ '''
534
+ for crop in unique_crops:
535
+ color = crop_colors[crop]
536
+ legend_html += f'<p style="margin: 5px 0;"><span style="color:{color}; font-size:16px;">●</span> {crop}</p>'
537
+ legend_html += '</div>'
538
+ pred_map.get_root().html.add_child(folium.Element(legend_html))
539
+
540
+ crop_stats = pred_df['crop'].value_counts()
541
+ stats = f"βœ… Polygon processed successfully!\n\nCrop Distribution (Province: {province}, Season: {season}):\n"
542
+ for crop, count in crop_stats.items():
543
+ percentage = (count / len(predicted_points)) * 100
544
+ stats += f"{crop}: {count} points ({percentage:.1f}%)\n"
545
+ for index in ['NDVI', 'NDWI', 'EVI', 'GNDVI', 'SAVI']:
546
+ avg = pred_df[index].mean()
547
+ stats += f"Average {index}: {avg:.3f}\n"
548
+
549
+ csv_file_path = f"crop_predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
550
+ try:
551
+ pred_df.to_csv(csv_file_path, index=False)
552
+ except Exception as e:
553
+ print(f"Error creating CSV file: {str(e)}")
554
+ csv_file_path = None
555
+
556
+ return stats, pred_map._repr_html_(), csv_file_path
557
+
558
+ # --- Instance Interface ---
559
+ def predict_instance(province, season, latitude, longitude, date, ndvi, ndwi, ndbi, red, green, blue, nir, swir):
560
+ static_features = {
561
+ "Province": province,
562
+ "Season": season,
563
+ "NDVI": ndvi,
564
+ "NDWI": ndwi,
565
+ "NDBI": ndbi,
566
+ "Red": red,
567
+ "Green": green,
568
+ "Blue": blue,
569
+ "NIR": nir,
570
+ "SWIR": swir,
571
+ "Date": date
572
+ }
573
+ crop = predict_crop_description({"latitude": latitude, "longitude": longitude}, static_features, scaler, feature_columns, province, season)
574
+ return f"{crop}"
575
+
576
+ from pathlib import Path
577
+ import gradio as gr
578
+
579
+ # Sample file paths
580
+ sample_dir = Path("samples") # Ensure this directory exists with .tif files
581
+ sample_files = {
582
+ "Sample 1": sample_dir / "sample1.tif",
583
+ "Sample 2": sample_dir / "sample2.tif"
584
+ }
585
+
586
+ # Function to simulate upload when sample is clicked
587
+ def load_sample_and_predict(sample_name, province, season, date):
588
+ file_path = sample_files[sample_name]
589
+ return process_upload(file_path, province, season, date)
590
+
591
+ # --- Gradio Interface ---
592
+ with gr.Blocks(title="Crop Predictor", theme=gr.themes.Soft()) as demo:
593
+ gr.Markdown("# 🌾 Crop Predictor")
594
+
595
+ with gr.Tabs():
596
+ with gr.TabItem("πŸ“€ Upload"):
597
+ gr.Markdown("Upload a .tiff or .tif file with bands [r, g, b, rededge, nir, swr1, swr2]")
598
+
599
+ file_input = gr.File(label="Upload .tiff/.tif file", file_types=[".tiff", ".tif"])
600
+
601
+ with gr.Row():
602
+ province = gr.Textbox(label="Province", value="Punjab")
603
+ season = gr.Textbox(label="Season", value="Rabi")
604
+
605
+ with gr.Row():
606
+ date = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
607
+
608
+ upload_btn = gr.Button("πŸ” Predict", variant="primary")
609
+ output_stats = gr.Textbox(label="Prediction Statistics", lines=10)
610
+ output_image = gr.Image(label="Prediction Result")
611
+
612
+ upload_btn.click(
613
+ fn=process_upload,
614
+ inputs=[file_input, province, season, date],
615
+ outputs=[output_stats, output_image]
616
+ )
617
+
618
+ # -- Add Sample File Buttons Here --
619
+ gr.Markdown("### Or try with a sample file:")
620
+ with gr.Row():
621
+ for name in sample_files:
622
+ gr.Button(name).click(
623
+ fn=load_sample_and_predict,
624
+ inputs=[gr.State(name), province, season, date],
625
+ outputs=[output_stats, output_image]
626
+ )
627
+
628
+ with gr.TabItem("πŸ—ΊοΈ Map"):
629
+ gr.Markdown("""
630
+ ## Interactive Polygon Crop Prediction
631
+
632
+ **Instructions:**
633
+ 1. Draw a polygon on the map below using the polygon tool.
634
+ 2. Click the "Export" button on the map to save the polygon as a GeoJSON file (polygon.geojson).
635
+ 3. Upload the exported GeoJSON file using the file input below.
636
+ 4. Adjust settings and click "πŸ” Predict" to process.
637
+ """)
638
+
639
+ map_html = gr.HTML(create_interactive_map, label="Draw Your Polygon Here")
640
+
641
+ with gr.Row():
642
+ geojson_input = gr.File(label="Upload Exported GeoJSON File")
643
+ select_btn = gr.Button("🎯 Select My Polygon", variant="secondary")
644
+ spacing = gr.Slider(
645
+ label="Grid Spacing (meters)",
646
+ minimum=10, maximum=1000, value=30, step=100
647
+ )
648
+
649
+ with gr.Row():
650
+ province_map = gr.Textbox(label="Province", value="Punjab")
651
+ season_map = gr.Textbox(label="Season", value="Multan")
652
+ date_map = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
653
+
654
+ polygon_status = gr.Textbox(
655
+ label="Selection Status",
656
+ value="⏳ Please draw a polygon, export it, and upload the GeoJSON file.",
657
+ interactive=False
658
+ )
659
+
660
+ predict_btn = gr.Button("πŸ” Predict Crops", variant="primary", size="lg")
661
+
662
+ output_map_stats = gr.Textbox(label="Prediction Results", lines=10)
663
+ output_map = gr.HTML(label="Crop Prediction Map")
664
+ output_csv = gr.File(label="πŸ“₯ Download Results CSV")
665
+
666
+ select_btn.click(
667
+ fn=select_polygon,
668
+ inputs=[geojson_input],
669
+ outputs=polygon_status
670
+ )
671
+
672
+ predict_btn.click(
673
+ fn=process_polygon_prediction,
674
+ inputs=[spacing, province_map, season_map, date_map, geojson_input],
675
+ outputs=[output_map_stats, output_map, output_csv]
676
+ )
677
+
678
+ with gr.TabItem("πŸ“Š Instance"):
679
+ gr.Markdown("## Single Point Prediction")
680
+ gr.Markdown("Enter features manually for a single point prediction")
681
+
682
+ with gr.Row():
683
+ province_inst = gr.Textbox(label="Province", value="Punjab")
684
+ season_inst = gr.Textbox(label="Season", value="Rabi")
685
+
686
+ with gr.Row():
687
+ latitude_inst = gr.Number(label="Latitude", value=30.809)
688
+ longitude_inst = gr.Number(label="Longitude", value=73.450)
689
+ date_inst = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
690
+
691
+ gr.Markdown("### Spectral Indices")
692
+ with gr.Row():
693
+ ndvi_inst = gr.Number(label="NDVI", value=0.65)
694
+ ndwi_inst = gr.Number(label="NDWI", value=-2.0)
695
+ ndbi_inst = gr.Number(label="NDBI", value=0.10)
696
+
697
+ gr.Markdown("### Band Values")
698
+ with gr.Row():
699
+ red_inst = gr.Number(label="Red", value=678)
700
+ green_inst = gr.Number(label="Green", value=732)
701
+ blue_inst = gr.Number(label="Blue", value=620)
702
+
703
+ with gr.Row():
704
+ nir_inst = gr.Number(label="NIR", value=3000)
705
+ swir_inst = gr.Number(label="SWIR", value=1800)
706
+
707
+ instance_btn = gr.Button("πŸ” Predict", variant="primary")
708
+ output_instance = gr.Textbox(label="Prediction Result", lines=3)
709
+
710
+ instance_btn.click(
711
+ fn=predict_instance,
712
+ inputs=[province_inst, season_inst, latitude_inst, longitude_inst,
713
+ date_inst, ndvi_inst, ndwi_inst, ndbi_inst, red_inst,
714
+ green_inst, blue_inst, nir_inst, swir_inst],
715
+ outputs=output_instance
716
+ )
717
+
718
+ if __name__ == "__main__":
719
  demo.launch(share=True)