prshntdxt committed on
Commit
c82cafe
·
1 Parent(s): 4581b61

Deploy Forest Segmentation API with LFS

Browse files
.dockerignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .Python
6
+ *.egg-info/
7
+ dist/
8
+ build/
9
+ .venv
10
+ venv/
11
+ env/
12
+ server/
13
+ .git
14
+ .gitignore
15
+ .env
16
+ .env.local
17
+ *.md
18
+ *.txt
19
+ !requirements.txt
20
+ logs/
21
+ *.hdf5
22
+ .DS_Store
23
+ .vscode/
24
+ .idea/
25
+ *.log
26
+ test_*.py
27
+ *_test.py
Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use official Python runtime as base image
FROM python:3.11-slim

# Set working directory in container
WORKDIR /app

# Install system dependencies (OpenCV, HDF5 support)
RUN apt-get update && apt-get install -y \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libhdf5-dev \
    pkg-config \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer is cached across code changes
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY main.py .
COPY model.py .
COPY schemas.py .
COPY inference/ ./inference/

# Copy pre-trained model (use .keras format only)
COPY models/Forest_Segmentation_Best.keras ./models/

# Create logs directory
RUN mkdir -p logs

# Expose port (Hugging Face Spaces default)
EXPOSE 7860

# Set environment for production
ENV PYTHONUNBUFFERED=1
ENV PORT=7860
ENV HOST=0.0.0.0

# Health check.
# FIX: use stdlib urllib instead of the third-party `requests` package (which
# may not be in requirements.txt), and rely on urlopen raising HTTPError for
# non-2xx responses — `requests.get(...)` without raise_for_status() reported
# "healthy" even when /health returned a 5xx.
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1

# Run FastAPI application with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Entry point for FastAPI application.
4
+ Runs uvicorn server on configurable host/port.
5
+ """
6
+
7
+ import os
8
+ import uvicorn
9
+ from main import app
10
+
11
+ if __name__ == "__main__":
12
+ host = os.getenv("HOST", "0.0.0.0")
13
+ port = int(os.getenv("PORT", "7860"))
14
+
15
+ print(f"Starting Forest Segmentation API on {host}:{port}")
16
+
17
+ uvicorn.run(
18
+ app,
19
+ host=host,
20
+ port=port,
21
+ log_level="info"
22
+ )
inference/__pycache__/forest.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
inference/forest.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow.keras.models import load_model
2
+ import numpy as np
3
+ import base64
4
+ import logging
5
+
6
+ MODEL_PATH = "models/Forest_Segmentation_Best.keras"
7
+ model = None
8
+ EPS = 1e-6
9
+
10
+ # Setup logging
11
+ logger = logging.getLogger("forest_segmentation.inference")
12
+
13
+ def load():
14
+ global model
15
+ if model is None:
16
+ logger.info("[INFERENCE] Loading model from: " + MODEL_PATH)
17
+ model = load_model(MODEL_PATH, compile=False)
18
+ logger.info("[INFERENCE] Model loaded successfully")
19
+
20
def decode_band_float32(b64):
    """Decode base64-encoded float32 band data to a square 2D array.

    Args:
        b64: base64 string whose decoded bytes are a flat float32 buffer of
            side*side values.

    Returns:
        (side, side) float32 ndarray.

    Raises:
        ValueError: if the buffer is not a perfect-square number of float32
            values. (Previously the implicit reshape failed with a confusing
            "cannot reshape" message when the payload was truncated.)
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)
    side = int(np.sqrt(arr.size))
    if side * side != arr.size:
        raise ValueError(
            f"Band payload has {arr.size} float32 values, which is not a square tile"
        )
    return arr.reshape((side, side))
26
+
27
def validate_landsat_data(bands_dict):
    """
    Validate that input data matches Landsat 8 Collection 2 Level 2 format.

    Every band must be a 2D array; non-float32 bands are coerced in place.
    (Expected value ranges — [-0.2, 0.6] optical, [-1, 1] indices — are
    documented but not enforced here.)

    Args:
        bands_dict: mapping band name -> 2D ndarray.

    Returns:
        The same mapping, with every band guaranteed float32.

    Raises:
        ValueError: if any band is not 2-dimensional.
    """
    for band_name, data in bands_dict.items():
        if data.ndim != 2:
            raise ValueError(f"{band_name}: Expected 2D array, got shape {data.shape}")
        if data.dtype != np.float32:
            # BUG FIX: the converted array was previously assigned to a local
            # and discarded; store it back so callers actually get float32.
            bands_dict[band_name] = data.astype(np.float32)
    return bands_dict
38
+
39
def ndvi(red, nir):
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red)."""
    numerator = nir - red
    denominator = nir + red + EPS  # EPS guards against divide-by-zero
    return numerator / denominator

def ndwi(green, nir):
    """Normalized Difference Water Index: (Green - NIR) / (Green + NIR)."""
    numerator = green - nir
    denominator = green + nir + EPS
    return numerator / denominator

def nbr(nir, swir2):
    """Normalized Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2)."""
    numerator = nir - swir2
    denominator = nir + swir2 + EPS
    return numerator / denominator
50
+
51
def analyze_input_bands(bands):
    """Log and return per-band summary statistics for the 9 expected inputs.

    Returns:
        dict mapping band name -> {min, max, mean, std}, plus
        'vegetation_coverage_pct' when an NDVI band is present.
    """
    band_order = ('Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2',
                  'NDVI', 'NDWI', 'NBR')
    stats = {}

    logger.info("[ANALYSIS] === INPUT BAND ANALYSIS ===")

    for name in band_order:
        if name not in bands:
            continue
        data = bands[name]
        summary = {
            "min": float(data.min()),
            "max": float(data.max()),
            "mean": float(data.mean()),
            "std": float(data.std()),
        }
        stats[name] = summary
        logger.info(
            f"[ANALYSIS] {name}: min={summary['min']:.4f}, "
            f"max={summary['max']:.4f}, mean={summary['mean']:.4f}"
        )

    # Rough vegetation coverage from the NDVI channel (threshold 0.5).
    if 'NDVI' in bands:
        ndvi_data = bands['NDVI']
        veg_pct = (np.sum(ndvi_data > 0.5) / ndvi_data.size) * 100
        logger.info(f"[ANALYSIS] NDVI > 0.5 (vegetation): {veg_pct:.2f}% of pixels")
        stats['vegetation_coverage_pct'] = veg_pct

    return stats
77
+
78
def preprocess_for_model(bands, clip_optical=False, clip_indices=False):
    """
    Preprocess bands to match model training expectations.

    Args:
        bands: Dictionary of band arrays
        clip_optical: If True, clip optical bands to [-0.2, 0.6]
        clip_indices: If True, clip indices to [-1, 1]

    Returns:
        New dictionary containing only the bands actually present in the
        input. FIX: the original inserted `None` placeholders (via
        bands.get) for missing bands when clipping was enabled — but skipped
        missing bands when clipping was off — and those None entries crashed
        downstream tensor assembly. Missing bands are now skipped uniformly.
    """
    optical_names = ('Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2')
    index_names = ('NDVI', 'NDWI', 'NBR')

    processed = {}

    if clip_optical:
        logger.info("[PREPROCESS] Clipping optical bands to [-0.2, 0.6]")
    for name in optical_names:
        if name in bands:
            processed[name] = (
                np.clip(bands[name], -0.2, 0.6) if clip_optical else bands[name]
            )

    if clip_indices:
        logger.info("[PREPROCESS] Clipping indices to [-1.0, 1.0]")
    for name in index_names:
        if name in bands:
            processed[name] = (
                np.clip(bands[name], -1.0, 1.0) if clip_indices else bands[name]
            )

    return processed
117
+
118
def build_input_tensor(bands):
    """
    Assemble the 9-channel model input from Landsat 8 bands.

    Band dict keys:
      - Blue, Green, Red          (channels 0-2, optical)
      - NIR, SWIR1, SWIR2         (channels 3-5, infrared)
      - NDVI, NDWI, NBR           (channels 6-8; used when supplied as a
                                   string or ndarray, otherwise derived
                                   from the bands above)

    Each value may be a base64-encoded float32 buffer or a numpy array.

    Returns:
        (1, H, W, 9) float32 array ready for model inference.
    """
    def as_array(value):
        # base64 strings come from the API; arrays come from local callers
        return decode_band_float32(value) if isinstance(value, str) else value

    blue, green, red = (as_array(bands[k]) for k in ("Blue", "Green", "Red"))
    nir, swir1, swir2 = (as_array(bands[k]) for k in ("NIR", "SWIR1", "SWIR2"))

    def index_channel(key, fallback):
        # Prefer a caller-supplied index; compute it only when absent.
        supplied = bands.get(key)
        if isinstance(supplied, (str, np.ndarray)):
            return as_array(supplied)
        return fallback()

    ndvi_map = index_channel("NDVI", lambda: ndvi(red, nir))
    ndwi_map = index_channel("NDWI", lambda: ndwi(green, nir))
    nbr_map = index_channel("NBR", lambda: nbr(nir, swir2))

    channels = [blue, green, red, nir, swir1, swir2, ndvi_map, ndwi_map, nbr_map]
    stacked = np.stack(channels, axis=-1).astype(np.float32)

    # Sanity-check the optical channels against the training value range.
    opt_min = np.min(stacked[..., :6])
    opt_max = np.max(stacked[..., :6])
    if opt_min < -0.3 or opt_max > 1.0:
        logger.warning(
            f"[BUILD] WARNING: Optical bands range [{opt_min:.4f}, {opt_max:.4f}] "
            "outside expected [-0.2, 0.6]"
        )

    # Add batch dimension: (1, H, W, 9)
    return np.expand_dims(stacked, axis=0)
174
+
175
+ def predict_forest(bands, debug=False, clip_optical=False, clip_indices=False):
176
+ """
177
+ Predict forest segmentation mask from Landsat 8 9-band input
178
+
179
+ Args:
180
+ bands: Dictionary with keys: Blue, Green, Red, NIR, SWIR1, SWIR2, NDVI, NDWI, NBR
181
+ debug: If True, return detailed debug statistics
182
+ clip_optical: If True, clip optical bands to [-0.2, 0.6]
183
+ clip_indices: If True, clip indices to [-1, 1]
184
+
185
+ Returns:
186
+ Dictionary with mask, confidence scores, and optional debug data
187
+ """
188
+ load()
189
+
190
+ # Analyze input
191
+ logger.info("[PREDICT] Starting prediction...")
192
+ input_stats = analyze_input_bands(bands)
193
+
194
+ # Preprocess if requested
195
+ if clip_optical or clip_indices:
196
+ logger.info("[PREDICT] Applying preprocessing (clip_optical={}, clip_indices={})...".format(clip_optical, clip_indices))
197
+ bands = preprocess_for_model(bands, clip_optical=clip_optical, clip_indices=clip_indices)
198
+
199
+ # Build input tensor
200
+ logger.info("[PREDICT] Building input tensor...")
201
+ x = build_input_tensor(bands)
202
+
203
+ # Run inference
204
+ logger.info("[PREDICT] Running model inference...")
205
+ pred = model.predict(x, verbose=0)[0, :, :, 0] # Extract (H, W) from (1, H, W, 1)
206
+
207
+ # Analyze output
208
+ logger.info("[PREDICT] === RAW MODEL OUTPUT ===")
209
+ logger.info(f"[PREDICT] Output shape: {pred.shape}, dtype: {pred.dtype}")
210
+ logger.info(f"[PREDICT] Output range: [{pred.min():.4f}, {pred.max():.4f}]")
211
+ logger.info(f"[PREDICT] Output mean: {pred.mean():.4f}, std: {pred.std():.4f}")
212
+ logger.info(f"[PREDICT] Pixels > 0.5: {np.sum(pred > 0.5):,} / {pred.size:,} ({100*np.sum(pred > 0.5)/pred.size:.2f}%)")
213
+ logger.info(f"[PREDICT] Pixels > 0.8: {np.sum(pred > 0.8):,} / {pred.size:,}")
214
+
215
+ # Generate binary mask
216
+ mask = (pred > 0.5).astype(np.uint8) * 255
217
+
218
+ # Calculate statistics
219
+ forest_confidence = float(np.mean(pred[pred > 0.5])) if np.any(pred > 0.5) else 0.0
220
+ forest_percentage = float((pred > 0.5).sum() / pred.size * 100)
221
+
222
+ result = {
223
+ "mask": mask.tolist(),
224
+ "forest_confidence": forest_confidence,
225
+ "forest_percentage": forest_percentage,
226
+ "mean_prediction": float(pred.mean()),
227
+ "classes": ["forest", "non-forest"],
228
+ "model_version": "landsat8_trained"
229
+ }
230
+
231
+ if debug:
232
+ logger.info("[PREDICT] Adding debug information...")
233
+ result["debug"] = {
234
+ "input_stats": input_stats,
235
+ "output_distribution": {
236
+ "min": float(pred.min()),
237
+ "max": float(pred.max()),
238
+ "mean": float(pred.mean()),
239
+ "std": float(pred.std()),
240
+ "percentile_10": float(np.percentile(pred, 10)),
241
+ "percentile_25": float(np.percentile(pred, 25)),
242
+ "percentile_50": float(np.percentile(pred, 50)),
243
+ "percentile_75": float(np.percentile(pred, 75)),
244
+ "percentile_90": float(np.percentile(pred, 90)),
245
+ "histogram": {
246
+ "0.0-0.1": int(np.sum((pred >= 0.0) & (pred < 0.1))),
247
+ "0.1-0.2": int(np.sum((pred >= 0.1) & (pred < 0.2))),
248
+ "0.2-0.3": int(np.sum((pred >= 0.2) & (pred < 0.3))),
249
+ "0.3-0.4": int(np.sum((pred >= 0.3) & (pred < 0.4))),
250
+ "0.4-0.5": int(np.sum((pred >= 0.4) & (pred < 0.5))),
251
+ "0.5-0.6": int(np.sum((pred >= 0.5) & (pred < 0.6))),
252
+ "0.6-0.7": int(np.sum((pred >= 0.6) & (pred < 0.7))),
253
+ "0.7-0.8": int(np.sum((pred >= 0.7) & (pred < 0.8))),
254
+ "0.8-0.9": int(np.sum((pred >= 0.8) & (pred < 0.9))),
255
+ "0.9-1.0": int(np.sum((pred >= 0.9) & (pred <= 1.0)))
256
+ }
257
+ }
258
+ }
259
+
260
+ logger.info("[PREDICT] Forest prediction: {:.2f}%".format(forest_percentage))
261
+ return result
main.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ from fastapi import FastAPI, HTTPException, Request
4
+ from fastapi.responses import JSONResponse
5
+ import tensorflow as tf
6
+ import numpy as np
7
+ import base64
8
+ import logging
9
+ import os
10
+ import sys
11
+ import time
12
+ from datetime import datetime
13
+ from logging.handlers import RotatingFileHandler
14
+
15
+ from inference.forest import predict_forest, build_input_tensor
16
+ from schemas import PredictRequest, PredictResponse
17
+
18
+ # =============================================================================
19
+ # LOGGING CONFIGURATION
20
+ # =============================================================================
21
+
22
+ os.makedirs("logs", exist_ok=True)
23
+
24
+ logger = logging.getLogger("forest_segmentation")
25
+ logger.setLevel(logging.DEBUG)
26
+
27
+ console_handler = logging.StreamHandler(sys.stdout)
28
+ console_handler.setLevel(logging.DEBUG)
29
+ console_handler.setFormatter(
30
+ logging.Formatter(
31
+ "%(asctime)s | %(levelname)-8s | %(message)s",
32
+ datefmt="%Y-%m-%d %H:%M:%S"
33
+ )
34
+ )
35
+
36
+ file_handler = RotatingFileHandler(
37
+ "logs/server.log", maxBytes=10_000_000, backupCount=5, encoding="utf-8"
38
+ )
39
+ file_handler.setFormatter(console_handler.formatter)
40
+
41
+ logger.addHandler(console_handler)
42
+ logger.addHandler(file_handler)
43
+
44
+ logger.info("=" * 80)
45
+ logger.info("FOREST SEGMENTATION SERVER STARTING")
46
+ logger.info("=" * 80)
47
+
48
+ # =============================================================================
49
+ # INVERSION DETECTION
50
+ # =============================================================================
51
+
52
def detect_inversion(image_stack, confidence_map, ndvi_threshold=0.3):
    """
    Heuristic check for an inverted model mask using the NDVI channel.

    If mean confidence over non-vegetated pixels exceeds mean confidence
    over vegetated pixels, the output is presumed inverted.

    Args:
        image_stack: (H, W, 9) input tensor; channel 6 is NDVI.
        confidence_map: (H, W) raw model confidence.
        ndvi_threshold: NDVI above which a pixel counts as vegetation.

    Returns:
        (is_inverted, correlation) where correlation is
        mean(conf | vegetation) - mean(conf | non-vegetation).
    """
    ndvi_channel = image_stack[:, :, 6]  # NDVI channel

    veg = ndvi_channel > ndvi_threshold
    non_veg = ~veg

    # Fall back to a neutral 0.5 when either region is empty.
    veg_conf = confidence_map[veg].mean() if veg.any() else 0.5
    non_veg_conf = confidence_map[non_veg].mean() if non_veg.any() else 0.5

    correlation = veg_conf - non_veg_conf
    return bool(non_veg_conf > veg_conf), float(correlation)
76
+
77
+ # =============================================================================
78
+ # FASTAPI APP
79
+ # =============================================================================
80
+
81
+ app = FastAPI(
82
+ title="Forest Segmentation API",
83
+ description="Landsat 8 Forest Segmentation",
84
+ version="1.0.0"
85
+ )
86
+
87
+ IMG_SIZE = 256
88
+ LANDSAT_BANDS = [
89
+ "Blue", "Green", "Red",
90
+ "NIR", "SWIR1", "SWIR2",
91
+ "NDVI", "NDWI", "NBR"
92
+ ]
93
+
94
+ # =============================================================================
95
+ # MIDDLEWARE
96
+ # =============================================================================
97
+
98
+ @app.middleware("http")
99
+ async def log_requests(request: Request, call_next):
100
+ start = time.time()
101
+ response = await call_next(request)
102
+ duration = time.time() - start
103
+ logger.info(
104
+ f"{request.method} {request.url.path} | "
105
+ f"{response.status_code} | {duration:.3f}s"
106
+ )
107
+ return response
108
+
109
+ # =============================================================================
110
+ # HEALTH
111
+ # =============================================================================
112
+
113
+ @app.get("/health")
114
+ def health():
115
+ return {
116
+ "status": "healthy",
117
+ "timestamp": datetime.utcnow().isoformat()
118
+ }
119
+
120
+ # =============================================================================
121
+ # PREDICT ENDPOINT (FIXED - CONTINUOUS VALUES)
122
+ # =============================================================================
123
+
124
+ @app.post("/predict", response_model=PredictResponse)
125
+ def predict(payload: PredictRequest):
126
+
127
+ try:
128
+ logger.info("[PREDICT] Request received")
129
+
130
+ if not payload.bands:
131
+ raise ValueError("No bands provided")
132
+
133
+ # ---------------------------------------------------------------------
134
+ # Decode bands
135
+ # ---------------------------------------------------------------------
136
+ decoded_bands = {}
137
+
138
+ for band, data in payload.bands.items():
139
+ if isinstance(data, str):
140
+ raw = base64.b64decode(data)
141
+ arr = np.frombuffer(raw, dtype=np.float32)
142
+ side = int(np.sqrt(arr.size))
143
+ decoded_bands[band] = arr.reshape((side, side))
144
+ else:
145
+ decoded_bands[band] = np.array(data, dtype=np.float32)
146
+
147
+ logger.info(f"[PREDICT] Decoded {len(decoded_bands)} bands")
148
+
149
+ # ---------------------------------------------------------------------
150
+ # Build input tensor
151
+ # ---------------------------------------------------------------------
152
+ input_tensor = build_input_tensor(decoded_bands) # (1, H, W, 9)
153
+ input_stack = input_tensor[0] # (H, W, 9)
154
+
155
+ # ---------------------------------------------------------------------
156
+ # Run model (raw confidence)
157
+ # ---------------------------------------------------------------------
158
+ model = tf.keras.models.load_model(
159
+ "models/Forest_Segmentation_Best.keras",
160
+ compile=False
161
+ )
162
+
163
+ confidence_map = model.predict(
164
+ input_tensor, verbose=0
165
+ )[0, :, :, 0]
166
+
167
+ # Log raw model output stats
168
+ logger.info(
169
+ f"[MODEL OUTPUT] Raw confidence: min={confidence_map.min():.4f}, "
170
+ f"max={confidence_map.max():.4f}, mean={confidence_map.mean():.4f}"
171
+ )
172
+
173
+ # ---------------------------------------------------------------------
174
+ # Inversion detection & correction
175
+ # ---------------------------------------------------------------------
176
+ is_inverted, corr = detect_inversion(
177
+ input_stack, confidence_map
178
+ )
179
+
180
+ if is_inverted:
181
+ logger.warning(
182
+ f"[INVERSION] Detected | NDVI correlation={corr:.4f} | FIX APPLIED"
183
+ )
184
+ corrected_conf = 1.0 - confidence_map
185
+ else:
186
+ logger.info(
187
+ f"[INVERSION] Not detected | NDVI correlation={corr:.4f}"
188
+ )
189
+ corrected_conf = confidence_map
190
+
191
+ # ---------------------------------------------------------------------
192
+ # Create masks (CONTINUOUS values for density visualization)
193
+ # ---------------------------------------------------------------------
194
+ # Use continuous confidence scaled to 0-255 (NOT binary!)
195
+ mask_255 = (corrected_conf * 255).astype(np.uint8)
196
+ inverted_mask_255 = (255 - mask_255).astype(np.uint8)
197
+
198
+ # Calculate stats using threshold for percentage
199
+ forest_percentage = float((corrected_conf > 0.5).sum() / corrected_conf.size * 100)
200
+ forest_confidence = float(corrected_conf.mean())
201
+
202
+ # Log mask stats to verify continuous values
203
+ logger.info(
204
+ f"[MASK] Range: [{mask_255.min()}, {mask_255.max()}] | "
205
+ f"Unique values: {len(np.unique(mask_255))}"
206
+ )
207
+
208
+ logger.info(
209
+ f"[PREDICT] Forest={forest_percentage:.2f}% | "
210
+ f"Confidence={forest_confidence:.4f}"
211
+ )
212
+
213
+ # ---------------------------------------------------------------------
214
+ # Response
215
+ # ---------------------------------------------------------------------
216
+ return {
217
+ "mask": mask_255.flatten().tolist(),
218
+ "inverted_mask": inverted_mask_255.flatten().tolist(),
219
+ "forest_percentage": forest_percentage,
220
+ "forest_confidence": forest_confidence,
221
+ "mean_prediction": forest_confidence,
222
+ "classes": {"forest": 1, "non_forest": 0},
223
+ "model_info": {
224
+ "name": "Forest_Segmentation_Best",
225
+ "bands": LANDSAT_BANDS
226
+ },
227
+ "debug": {
228
+ "was_inverted": is_inverted,
229
+ "inversion_correlation": corr,
230
+ "mask_min": int(mask_255.min()),
231
+ "mask_max": int(mask_255.max()),
232
+ "unique_values": int(len(np.unique(mask_255)))
233
+ }
234
+ }
235
+
236
+ except ValueError as e:
237
+ logger.error(f"[PREDICT] Validation error: {e}")
238
+ raise HTTPException(status_code=400, detail=str(e))
239
+
240
+ except Exception as e:
241
+ logger.exception("[PREDICT] Inference failed")
242
+ raise HTTPException(status_code=500, detail=str(e))
243
+
244
+ # =============================================================================
245
+ # STARTUP / SHUTDOWN
246
+ # =============================================================================
247
+
248
+ @app.on_event("startup")
249
+ async def startup():
250
+ logger.info("=" * 80)
251
+ logger.info("SERVER READY")
252
+ logger.info("=" * 80)
253
+
254
+ @app.on_event("shutdown")
255
+ async def shutdown():
256
+ logger.info("=" * 80)
257
+ logger.info("SERVER SHUTDOWN")
258
+ logger.info("=" * 80)
model.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras.models import load_model
4
+ import base64
5
+ import io
6
+
7
+ MODEL_PATH = "models/Forest_Segmentation_Best.keras"
8
+ EPS = 1e-6
9
+ model = None
10
+
11
+ # ----------------------------
12
+ # Load model once
13
+ # ----------------------------
14
+ def load_segmentation_model():
15
+ global model
16
+ if model is None:
17
+ model = load_model(MODEL_PATH, compile=False)
18
+
19
+ # ----------------------------
20
+ # Decode Landsat band from base64
21
+ # ----------------------------
22
def decode_band_float32(b64):
    """Decode base64-encoded float32 data to a square 2D array.

    Args:
        b64: base64 string whose decoded bytes are a flat float32 buffer of
            side*side values (assumes a square tile).

    Returns:
        (side, side) float32 ndarray.

    Raises:
        ValueError: if the buffer is not a perfect-square number of float32
            values. (Previously a truncated payload failed the reshape with
            a confusing message.)
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)
    side = int(np.sqrt(arr.size))  # assumes square tile
    if side * side != arr.size:
        raise ValueError(
            f"Band payload has {arr.size} float32 values, which is not a square tile"
        )
    arr = arr.reshape((side, side))
    return arr
29
+
30
+ # ----------------------------
31
+ # Spectral Indices (matching training pipeline)
32
+ # ----------------------------
33
def ndvi(red, nir):
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red)."""
    numerator = nir - red
    denominator = nir + red + EPS  # EPS guards against divide-by-zero
    return numerator / denominator

def ndwi(green, nir):
    """Normalized Difference Water Index: (Green - NIR) / (Green + NIR)."""
    numerator = green - nir
    denominator = green + nir + EPS
    return numerator / denominator

def nbr(nir, swir2):
    """Normalized Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2)."""
    numerator = nir - swir2
    denominator = nir + swir2 + EPS
    return numerator / denominator
44
+
45
+ # ----------------------------
46
+ # Build 9-channel tensor from Landsat 8
47
+ # ----------------------------
48
+ def build_input_tensor(bands):
49
+ """
50
+ Build 9-channel input tensor from Landsat 8 Collection 2 Level 2 data
51
+
52
+ Args:
53
+ bands: Dictionary with keys:
54
+ - Blue, Green, Red: Optical bands (0-2)
55
+ - NIR, SWIR1, SWIR2: Infrared bands (3-5)
56
+ - NDVI, NDWI, NBR: Indices (6-8)
57
+
58
+ Values can be:
59
+ - Base64-encoded float32 strings (from API)
60
+ - Numpy arrays (from direct processing)
61
+
62
+ Returns:
63
+ (1, H, W, 9) array ready for model inference
64
+
65
+ Expected value range:
66
+ - Optical bands: [-0.2, 0.6]
67
+ - Indices: [-1, 1]
68
+ """
69
+ # Extract and decode optical bands
70
+ blue = decode_band_float32(bands["Blue"]) if isinstance(bands["Blue"], str) else bands["Blue"]
71
+ green = decode_band_float32(bands["Green"]) if isinstance(bands["Green"], str) else bands["Green"]
72
+ red = decode_band_float32(bands["Red"]) if isinstance(bands["Red"], str) else bands["Red"]
73
+ nir = decode_band_float32(bands["NIR"]) if isinstance(bands["NIR"], str) else bands["NIR"]
74
+ swir1 = decode_band_float32(bands["SWIR1"]) if isinstance(bands["SWIR1"], str) else bands["SWIR1"]
75
+ swir2 = decode_band_float32(bands["SWIR2"]) if isinstance(bands["SWIR2"], str) else bands["SWIR2"]
76
+
77
+ # Use pre-calculated indices if provided, otherwise compute them
78
+ if "NDVI" in bands and bands["NDVI"] is not None:
79
+ ndvi_map = decode_band_float32(bands["NDVI"]) if isinstance(bands["NDVI"], str) else bands["NDVI"]
80
+ else:
81
+ ndvi_map = ndvi(red, nir)
82
+
83
+ if "NDWI" in bands and bands["NDWI"] is not None:
84
+ ndwi_map = decode_band_float32(bands["NDWI"]) if isinstance(bands["NDWI"], str) else bands["NDWI"]
85
+ else:
86
+ ndwi_map = ndwi(green, nir)
87
+
88
+ if "NBR" in bands and bands["NBR"] is not None:
89
+ nbr_map = decode_band_float32(bands["NBR"]) if isinstance(bands["NBR"], str) else bands["NBR"]
90
+ else:
91
+ nbr_map = nbr(nir, swir2)
92
+
93
+ # Stack into (H, W, 9) - matches training data format exactly
94
+ stacked = np.stack([
95
+ blue,
96
+ green,
97
+ red,
98
+ nir,
99
+ swir1,
100
+ swir2,
101
+ ndvi_map,
102
+ ndwi_map,
103
+ nbr_map
104
+ ], axis=-1)
105
+
106
+ stacked = stacked.astype(np.float32)
107
+ stacked = np.expand_dims(stacked, axis=0) # (1, H, W, 9)
108
+
109
+ return stacked
110
+
111
+ # ----------------------------
112
+ # Inference
113
+ # ----------------------------
114
+ def predict_segmentation(bands):
115
+ """
116
+ Predict forest segmentation mask
117
+
118
+ Args:
119
+ bands: Dictionary with Landsat 8 bands
120
+
121
+ Returns:
122
+ Dictionary with:
123
+ - mask: (H, W) binary segmentation
124
+ - forest_percentage: % of pixels classified as forest
125
+ - forest_confidence: average confidence on forest pixels
126
+ - metadata: model and input information
127
+ """
128
+ load_segmentation_model()
129
+
130
+ x = build_input_tensor(bands)
131
+ pred = model.predict(x, verbose=0)[0, :, :, 0]
132
+
133
+ # Generate binary mask
134
+ mask = (pred > 0.5).astype(np.uint8) * 255
135
+
136
+ # Calculate statistics
137
+ forest_confidence = float(np.mean(pred[pred > 0.5])) if np.any(pred > 0.5) else 0.0
138
+ forest_percentage = float((pred > 0.5).sum() / pred.size * 100)
139
+
140
+ return {
141
+ "mask": mask.tolist(),
142
+ "forest_percentage": forest_percentage,
143
+ "forest_confidence": forest_confidence,
144
+ "mean_prediction": float(pred.mean()),
145
+ "classes": ["forest", "non-forest"],
146
+ "model_info": {
147
+ "training_data": "Landsat 8 Collection 2 Level 2",
148
+ "bands": ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2", "NDVI", "NDWI", "NBR"],
149
+ "patch_size": 256,
150
+ "value_range": "[-0.2, 0.6] for optical, [-1, 1] for indices"
151
+ }
152
+ }
models/Forest_Segmentation_Best.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:248cd02fc2a59e7f82e1a74b9593779d557767a8b488c6b0c0caefd416a87453
3
+ size 520724556
requirements.txt ADDED
Binary file (1.87 kB). View file
 
schemas.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Dict, List, Optional, Union
3
+
4
+ class PredictRequest(BaseModel):
5
+ """
6
+ Forest segmentation prediction request
7
+
8
+ Supports Landsat 8 Collection 2 Level 2 data format.
9
+ Each band can be provided as:
10
+ - Base64-encoded float32 data (for remote API calls)
11
+ - Array/list format (for direct server calls)
12
+
13
+ Required bands:
14
+ - Blue, Green, Red: Optical bands
15
+ - NIR, SWIR1, SWIR2: Infrared bands
16
+ - NDVI, NDWI, NBR: Spectral indices (or server will compute them)
17
+
18
+ Optional special keys:
19
+ - _invert_mask: Set to true to invert forest/non-forest in response
20
+
21
+ Value range expectations:
22
+ - Optical bands: [-0.2, 0.6]
23
+ - Indices: [-1, 1]
24
+ """
25
+ model_name: str = "forest_segmentation"
26
+ model_version: str = "landsat8_v1"
27
+ bands: Dict[str, Union[str, List, int]] # Band data as base64 or array
28
+
29
+
30
# FIX: the original re-imported typing and pydantic here in the middle of the
# module; only `Any` is actually missing from the top-of-file typing import.
from typing import Any


class PredictResponse(BaseModel):
    """Response payload for /predict.

    Masks are flattened row-major uint8 values (0-255), one entry per pixel.
    """
    mask: List[int]                          # continuous confidence scaled to 0-255, flattened
    inverted_mask: List[int]                 # 255 - mask, flattened
    forest_percentage: float                 # % of pixels with corrected confidence > 0.5
    forest_confidence: float                 # mean corrected confidence over the tile
    mean_prediction: float
    classes: Dict[str, int]                  # e.g. {"forest": 1, "non_forest": 0}
    model_info: Dict[str, Any]
    debug: Optional[Dict[str, Any]] = None   # inversion / mask diagnostics