import os
import tempfile
from io import BytesIO  # NOTE(review): not used anywhere in this module

import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling


class TritonPythonModel:
    """Triton Python-backend model that runs OmniCloudMask cloud detection.

    Each request must carry an ``input_jp2_bytes`` tensor holding three encoded
    single-band rasters in the order red, green, NIR. Green and NIR are resampled
    onto the red band's grid. The three bands are stacked as a
    (3, height, width) float32 array and passed to ``predict_from_array``. The
    resulting mask is returned flattened as the ``output_mask`` tensor (uint8).
    """

    def initialize(self, args):
        """Called once by Triton when the model is loaded.

        Args:
            args: Model arguments supplied by Triton; not used by this model.
        """
        print('Initialized Cloud Detection model with JP2 input and robust GDAL handling')

    # ------------------------------------------------------------------
    # Decoding helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _float32_profile(height, width):
        """Return a minimal single-band float32 GTiff profile of the given size."""
        return {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}

    @staticmethod
    def _read_from_memoryfile(raw_bytes):
        """Decode band 1 of *raw_bytes* through rasterio's ``MemoryFile``.

        Returns:
            (data, height, width, profile): ``data`` is cast to float32 and
            ``profile`` is the source dataset's profile.
        """
        with MemoryFile(raw_bytes) as memfile, memfile.open() as src:
            return src.read(1).astype(np.float32), src.height, src.width, src.profile

    @staticmethod
    def _read_via_tempfile(raw_bytes, suffix):
        """Write *raw_bytes* to a temporary file ending in *suffix* and decode band 1.

        The temporary file is always removed, including when decoding fails.
        (Previously it leaked whenever ``rasterio.open`` raised.)

        Returns:
            (data, height, width, profile): ``data`` is cast to float32 and
            ``profile`` is the source dataset's profile.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        try:
            # Close the write handle before GDAL opens the path so all bytes are on disk.
            with os.fdopen(fd, 'wb') as tmp_file:
                tmp_file.write(raw_bytes)
            with rasterio.open(path) as src:
                return src.read(1).astype(np.float32), src.height, src.width, src.profile
        finally:
            try:
                os.unlink(path)
            except OSError as cleanup_error:
                # Best-effort cleanup: do not mask the decode result/error.
                print(f"Could not remove temporary file {path}: {cleanup_error}")

    @classmethod
    def _read_via_tiff_suffix(cls, raw_bytes):
        """Retry the temporary-file decode with a ``.tiff`` suffix.

        The source profile is discarded and replaced with a minimal float32 GTiff
        profile of the decoded size.
        """
        data, height, width, _ = cls._read_via_tempfile(raw_bytes, '.tiff')
        return data, height, width, cls._float32_profile(height, width)

    @classmethod
    def _interpret_raw_float32(cls, raw_bytes):
        """Interpret *raw_bytes* as a flat float32 buffer holding a square image.

        This is a last-resort path for clients that send raw numpy bytes instead
        of an encoded raster.

        Raises:
            ValueError: if the buffer length is not a multiple of 4 bytes, or the
                element count is zero or not a perfect square.
        """
        data_array = np.frombuffer(raw_bytes, dtype=np.float32)
        side_length = int(np.sqrt(len(data_array)))
        # Typical tile sizes (10980, 5490, 1024, 512 per side) are all square, so
        # the perfect-square check already covers them.
        if side_length == 0 or side_length * side_length != len(data_array):
            raise ValueError(f"Cannot interpret data array of length {len(data_array)} as image")
        data = data_array.reshape(side_length, side_length)
        return data, side_length, side_length, cls._float32_profile(side_length, side_length)

    def safe_read_jp2_bytes(self, jp2_bytes):
        """Decode one encoded band, trying several strategies in turn.

        Strategies, in order:
          1. rasterio ``MemoryFile`` (in-memory decode).
          2. Temporary ``.jp2`` file on disk.
          3. Temporary ``.tiff`` file on disk (profile replaced by a minimal
             float32 GTiff profile).
          4. Raw float32 buffer forming a square image.

        Args:
            jp2_bytes: Encoded raster bytes (expected to be JPEG 2000).

        Returns:
            (data, height, width, profile), where ``data`` is a 2-D float32 array.

        Raises:
            RuntimeError: if every strategy fails. The message lists each
                strategy's error, and the last error is chained as the cause.
        """
        strategies = (
            ("MemoryFile", "Method 1 (MemoryFile)", self._read_from_memoryfile),
            ("TempFile", "Method 2 (temporary file)", lambda raw: self._read_via_tempfile(raw, '.jp2')),
            ("TiffFallback", "Method 3 (tiff fallback)", self._read_via_tiff_suffix),
            ("RawBytes", "Method 4 (raw float32 bytes)", self._interpret_raw_float32),
        )
        failures = []
        last_error = None
        for label, description, reader in strategies:
            try:
                return reader(jp2_bytes)
            except Exception as err:
                print(f"{description} failed: {err}")
                failures.append(f"{label}({err})")
                last_error = err
        raise RuntimeError(f"All fallback methods failed: {', '.join(failures)}") from last_error

    # ------------------------------------------------------------------
    # Resampling helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resample_nearest(data, current_height, current_width, target_height, target_width):
        """Nearest-neighbour resample by picking evenly spaced source rows and columns."""
        row_idx = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
        col_idx = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)
        return data[np.ix_(row_idx, col_idx)].astype(np.float32)

    @classmethod
    def _resample_with_rasterio(cls, data, current_height, current_width,
                                target_height, target_width, profile):
        """Bilinear resample through a temporary in-memory GTiff dataset.

        The GTiff driver is used regardless of the source driver. A JP2 source
        profile names the JP2OpenJPEG driver, which GDAL cannot open for direct
        float32 writing, so reusing the source profile verbatim made this path
        always fail for JP2 input.
        """
        temp_profile = cls._float32_profile(current_height, current_width)
        # Keep the source nodata value, as the previous full-profile copy did.
        nodata = profile.get('nodata') if profile else None
        if nodata is not None:
            temp_profile['nodata'] = nodata
        with MemoryFile() as memfile, memfile.open(**temp_profile) as dataset:
            dataset.write(data, 1)
            return dataset.read(
                out_shape=(1, target_height, target_width),
                resampling=Resampling.bilinear,
            )[0].astype(np.float32)

    def safe_resample_data(self, data, current_height, current_width, target_height, target_width, profile):
        """Resample a 2-D band to (target_height, target_width).

        Tries these in order and returns the first that succeeds:
          1. rasterio bilinear resampling.
          2. ``scipy.ndimage.zoom`` (``order=1``), when SciPy is importable and
             yields exactly the target shape.
          3. Nearest-neighbour index selection.

        Args:
            data: 2-D array of shape (current_height, current_width).
            current_height, current_width: Size of ``data``.
            target_height, target_width: Requested output size.
            profile: Source rasterio profile; only its ``nodata`` entry is used.

        Returns:
            ``data`` unchanged when it already has the target size. Otherwise a
            float32 array of shape (target_height, target_width).

        Raises:
            ValueError: if resampling is needed and any dimension is not positive.
        """
        if current_height == target_height and current_width == target_width:
            return data
        if min(current_height, current_width, target_height, target_width) <= 0:
            raise ValueError(
                f"Cannot resample {current_height}x{current_width} to {target_height}x{target_width}"
            )
        data = np.asarray(data, dtype=np.float32)

        try:
            return self._resample_with_rasterio(data, current_height, current_width,
                                                target_height, target_width, profile)
        except Exception as e1:
            print(f"Rasterio resampling failed: {e1}")

        try:
            from scipy import ndimage
        except ImportError:
            print("Scipy not available for resampling")
        else:
            try:
                zoom_factors = (target_height / current_height, target_width / current_width)
                resampled = ndimage.zoom(data, zoom_factors, order=1)
                # zoom derives the output size by rounding; accept only an exact match.
                if resampled.shape == (target_height, target_width):
                    return resampled.astype(np.float32)
                print(f"Scipy resampling produced shape {resampled.shape}, "
                      f"expected {(target_height, target_width)}")
            except Exception as e2:
                print(f"Scipy resampling failed: {e2}")

        return self._resample_nearest(data, current_height, current_width, target_height, target_width)

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    @staticmethod
    def _error_response(message):
        """Build an InferenceResponse with no outputs and a TritonError carrying *message*."""
        error = pb_utils.TritonError(message)
        return pb_utils.InferenceResponse(output_tensors=[], error=error)

    def _load_resampled_band(self, band_bytes, target_height, target_width):
        """Decode one encoded band and resample it onto the target grid."""
        data, height, width, profile = self.safe_read_jp2_bytes(band_bytes)
        return self.safe_resample_data(data, height, width, target_height, target_width, profile)

    def _process_request(self, request):
        """Run cloud detection for a single request and return its response.

        Input-validation failures are returned as error responses. Any other
        exception propagates to ``execute``.
        """
        input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
        if input_tensor is None:
            return self._error_response("Missing required input tensor 'input_jp2_bytes'")

        # Flatten so both a [3] tensor and a batched [1, 3] tensor yield three items.
        jp2_bytes_list = input_tensor.as_numpy().reshape(-1)
        if len(jp2_bytes_list) != 3:
            return self._error_response(
                f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
            )

        red_bytes, green_bytes, nir_bytes = jp2_bytes_list
        print(f"Processing JP2 data - sizes: Red={len(red_bytes)}, Green={len(green_bytes)}, NIR={len(nir_bytes)}")

        # The red band defines the output grid; green and NIR are resampled onto it.
        red_data, target_height, target_width, _ = self.safe_read_jp2_bytes(red_bytes)
        print(f"Red band: {red_data.shape}, target dimensions: {target_height}x{target_width}")

        green_data = self._load_resampled_band(green_bytes, target_height, target_width)
        print(f"Green band after resampling: {green_data.shape}")

        nir_data = self._load_resampled_band(nir_bytes, target_height, target_width)
        print(f"NIR band after resampling: {nir_data.shape}")

        if not (red_data.shape == green_data.shape == nir_data.shape):
            shapes = [red_data.shape, green_data.shape, nir_data.shape]
            return self._error_response(f"Band shape mismatch after resampling: {shapes}")

        # Stack bands channel-first: (channels, height, width).
        prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
        print(f"Final prediction array shape: {prediction_array.shape}")

        cloud_mask = predict_from_array(prediction_array)
        print(f"Cloud mask shape: {cloud_mask.shape}")

        # The mask is sent back flattened; the client must know the image size.
        if cloud_mask.ndim > 1:
            cloud_mask = cloud_mask.flatten()

        # Output is cast to uint8 (the model config is expected to declare TYPE_UINT8).
        output_tensor = pb_utils.Tensor("output_mask", cloud_mask.astype(np.uint8))
        return pb_utils.InferenceResponse(output_tensors=[output_tensor])

    def execute(self, requests):
        """Handle a list of inference requests.

        Returns:
            One InferenceResponse per request, in request order. An exception
            raised while handling one request becomes an error response for that
            request only.
        """
        responses = []
        for request in requests:
            try:
                responses.append(self._process_request(request))
            except Exception as e:
                error_msg = f"Error processing JP2 data: {str(e)}"
                print(f"Model execution error: {error_msg}")
                responses.append(self._error_response(error_msg))
        return responses

    def finalize(self):
        """Called once by Triton when the model is unloaded."""
        print('Cloud Detection model finalized')