from fastapi import FastAPI, UploadFile, File, HTTPException, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import io
import uuid
from datetime import datetime, timedelta
import base64
import pydicom
import os
import tempfile

# Configuration
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
PORT = 7860
# (width, height) every preprocessing path resizes to before inference.
TARGET_SIZE = (540, 540)
# Set CHEXNET_DEBUG_DUMP=1 to save the first bytes of each upload to
# ./debug_file_sample.bin. Off by default: uploads are medical images whose
# headers may carry patient data, and all requests share one output file.
DEBUG_DUMP_UPLOADS = os.environ.get("CHEXNET_DEBUG_DUMP", "").strip().lower() in {"1", "true", "yes"}

app = FastAPI(
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
    version="5.0.0"
)

# Rate limiter setup
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# Without this handler slowapi's RateLimitExceeded escapes as an unhandled
# exception (HTTP 500) instead of a 429 "Too Many Requests" response.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS configuration
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins (or drop credentials)
# once the real frontend origins are known.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model configuration
# DenseNet121 layer whose activations are visualised by Grad-CAM.
layer_name = 'conv5_block16_concat'
# Output labels indexed by the classifier's output unit.
# NOTE(review): build_model's head has 14 outputs, so only the first 14 names
# can ever be reported; 'No Finding' has no corresponding output. The order
# must match the label order the weights were trained with — verify.
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
    'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
]


def build_model():
    """Build an untrained DenseNet121 with a 14-way sigmoid multi-label head.

    The input accepts any spatial size with 3 channels; weights are loaded
    separately by load_model_with_fallback.
    """
    base_model = DenseNet121(
        weights=None,
        include_top=False,
        input_shape=(None, None, 3)
    )
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # Independent sigmoid per label: one image may carry several findings.
    predictions = Dense(14, activation='sigmoid')(x)
    return Model(inputs=base_model.input, outputs=predictions)


def load_model_with_fallback():
    """Load the classifier, trying two weight files in order.

    1. Build the architecture and load weights from 'pretrained_model.h5'.
    2. Otherwise load the full saved model from 'Densenet.h5' (uncompiled).

    Raises:
        RuntimeError: if both strategies fail (chained to the last error).
    """
    try:
        model = build_model()
        model.load_weights('pretrained_model.h5')
        return model
    except Exception as primary_err:
        print(f"Primary loading failed: {primary_err}")
    try:
        return load_model('Densenet.h5', compile=False)
    except Exception as fallback_err:
        print(f"Fallback loading failed: {fallback_err}")
        raise RuntimeError("All model loading strategies failed") from fallback_err


# Load model once at import time; the service cannot run without it.
try:
    model = load_model_with_fallback()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Model loading failed: {e}")
    raise

# Lazily-built model returning (layer_name activations, predictions); see _get_grad_model.
_grad_model = None


def _get_grad_model():
    """Return the Grad-CAM helper model, building it on first use.

    Cached so the Keras graph is not rebuilt on every request.
    """
    global _grad_model
    if _grad_model is None:
        _grad_model = Model(
            inputs=model.inputs,
            outputs=[model.get_layer(layer_name).output, model.output]
        )
    return _grad_model


def generate_gradcam(img):
    """Return a Grad-CAM overlay for the model's highest-scoring class.

    Args:
        img: RGB uint8 image as returned by preprocess_image, shape (H, W, 3).

    Returns:
        PIL.Image.Image: ``img`` blended 50/50 with a jet-coloured heatmap. The
        heatmap is the ``layer_name`` activations weighted by the mean of the
        top-class gradients, keeping only positions where both the activation
        and the gradient are positive.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    grad_model = _get_grad_model()
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate only the top class score, not the sum of all outputs.
        class_score = predictions[:, class_idx]

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1).numpy()

    heatmap = np.maximum(cam, 0)
    peak = float(np.max(heatmap))
    # An all-zero CAM would otherwise divide by zero and yield NaNs.
    if peak > 0:
        heatmap = heatmap / peak

    heatmap_rgb = plt.cm.jet(heatmap)[..., :3]
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_rgb * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)


def process_predictions(predictions):
    """Pair model scores with class names, sorted from highest to lowest.

    Args:
        predictions: Array-like of shape (batch, n_outputs) of model scores.

    Returns:
        One list per batch row of ``(class_name, float_score)`` tuples in
        descending score order. Outputs without a matching entry in
        ``class_names`` are ignored, and names without a matching output are
        omitted.
    """
    decoded = []
    for pred in predictions:
        pred = np.asarray(pred)
        usable = min(len(pred), len(class_names))
        top_indices = np.argsort(pred[:usable])[::-1]
        decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
    return decoded


def dump_file_sample(file_bytes, filename="debug_file_sample.bin"):
    """Save up to the first 512 bytes of an upload and print its first 16 as hex.

    Best-effort debugging aid: any failure is printed, never raised.
    """
    try:
        sample_size = min(512, len(file_bytes))
        with open(filename, "wb") as f:
            f.write(file_bytes[:sample_size])
        print(f"Saved {sample_size} bytes sample to {filename}")

        hex_sample = ' '.join(f'{b:02x}' for b in file_bytes[:16])
        print(f"First 16 bytes: {hex_sample}")
    except Exception as e:
        print(f"Failed to save debug sample: {e}")


def _int_attr(dataset, name, default):
    """Read an integer DICOM attribute, falling back to ``default`` if absent, empty or malformed."""
    try:
        value = getattr(dataset, name, default)
        return int(value) if value not in (None, "") else default
    except (TypeError, ValueError):
        return default


def _extract_dicom_pixels(dicom_data, temp_img_file):
    """Try several strategies to obtain the pixel ndarray; return None if all fail.

    Args:
        dicom_data: Parsed pydicom dataset.
        temp_img_file: Scratch path used by the re-save fallbacks; the caller
            is responsible for deleting it.
    """
    methods_tried = []

    # Method 1: pydicom's own pixel decoder.
    methods_tried.append("Direct pixel_array")
    try:
        img = dicom_data.pixel_array
        if img.size == 0:
            raise ValueError("Extracted pixel array is empty")
        print(f"Successfully extracted pixel data via pixel_array: shape={img.shape}, dtype={img.dtype}")
        return img
    except Exception as e:
        print(f"Method 1 (direct pixel_array) failed: {e}")

    # NOTE(review): Methods 2 and 3 re-save the dataset with save_as(), which
    # writes DICOM (not PNG) regardless of the file extension, then ask
    # OpenCV / PIL to decode it. They will likely rarely succeed — confirm
    # whether they are still needed.

    # Method 2: OpenCV on the re-saved file.
    methods_tried.append("PNG intermediate")
    try:
        print("Trying PNG intermediate method...")
        dicom_data.save_as(temp_img_file)
        # IMREAD_UNCHANGED first to preserve bit depth, then plain grayscale.
        img = cv2.imread(temp_img_file, cv2.IMREAD_UNCHANGED)
        if img is None or img.size == 0:
            img = cv2.imread(temp_img_file, cv2.IMREAD_GRAYSCALE)
        if img is None or img.size == 0:
            raise ValueError("PNG conversion resulted in empty image")
        print(f"Successfully extracted pixel data via PNG: shape={img.shape}, dtype={img.dtype}")
        return img
    except Exception as e:
        print(f"Method 2 (PNG intermediate) failed: {e}")

    # Method 3: PIL on the re-saved file.
    methods_tried.append("PIL intermediate")
    try:
        print("Trying PIL intermediate method...")
        dicom_data.save_as(temp_img_file)
        # Context manager closes the file handle so the temp file can be removed.
        with Image.open(temp_img_file) as pil_img:
            img = np.array(pil_img)
        if img.size == 0:
            raise ValueError("PIL conversion resulted in empty image")
        print(f"Successfully extracted pixel data via PIL: shape={img.shape}, dtype={img.dtype}")
        return img
    except Exception as e:
        print(f"Method 3 (PIL intermediate) failed: {e}")

    print(f"All pixel data extraction methods failed: {', '.join(methods_tried)}")
    return None


def _normalize_to_uint8(img):
    """Linearly stretch a non-uint8 image's [min, max] onto [0, 255] as uint8.

    uint8 input is returned unchanged. A constant-valued image becomes
    uniform mid-gray (128).
    """
    if img.dtype == np.uint8:
        return img
    img_min = float(np.min(img))
    img_max = float(np.max(img))
    if not img_max > img_min:
        print("Image has uniform pixel values, using mid-gray")
        return np.full(img.shape, 128, dtype=np.uint8)
    scaled = 255.0 * (img.astype(np.float32) - img_min) / (img_max - img_min)
    img = scaled.astype(np.uint8)
    print(f"Normalized to 8-bit: new range=[{np.min(img)}, {np.max(img)}]")
    return img


def _to_rgb(img):
    """Convert a uint8 image laid out as (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) to (H, W, 3).

    Other channel counts are averaged to gray, then replicated to 3 channels.

    Raises:
        ValueError: if the array is not 2-D or 3-D.
    """
    if img.ndim == 2:
        print("Converting grayscale to RGB using manual conversion")
        return np.stack([img, img, img], axis=-1)
    if img.ndim != 3:
        raise ValueError(f"Invalid image dimensions: {img.shape}")

    channels = img.shape[2]
    if channels == 1:
        print("Converting single-channel 3D array to RGB")
        return np.repeat(img, 3, axis=2)
    if channels == 3:
        print("Image already has 3 channels, ensuring RGB color space")
        return img
    if channels == 4:
        print("Converting RGBA to RGB by removing alpha channel")
        # Contiguous copy: the channel slice is a strided view.
        return np.ascontiguousarray(img[:, :, :3])

    print(f"Unusual channel count ({channels}), converting to grayscale then RGB")
    if np.max(img) == 0:
        return np.full((img.shape[0], img.shape[1], 3), 128, dtype=np.uint8)
    gray = np.mean(img, axis=2).astype(np.uint8)
    return np.stack([gray, gray, gray], axis=-1)


def preprocess_dicom(file_bytes):
    """Decode a DICOM upload into a TARGET_SIZE RGB uint8 image for the model.

    Args:
        file_bytes: Raw bytes of the uploaded DICOM file.

    Returns:
        np.ndarray of shape (540, 540, 3), dtype uint8, RGB order.

    Raises:
        ValueError: if the file cannot be parsed, has no pixel data, or no
            extraction strategy yields a usable image. Other errors raised by
            pydicom / OpenCV propagate unchanged.
    """
    uid = str(uuid.uuid4())[:8]
    # Scratch file for the re-save fallbacks in _extract_dicom_pixels.
    temp_img_file = os.path.join(tempfile.gettempdir(), f"temp_dicom_img_{uid}.png")

    try:
        print(f"Processing DICOM file of size {len(file_bytes)} bytes")

        # Parse from memory. force=True reads the data even when the standard
        # file header is missing; defer_size=None loads every element eagerly.
        try:
            dicom_data = pydicom.dcmread(io.BytesIO(file_bytes), force=True, defer_size=None)
        except Exception as e:
            print(f"Error reading DICOM file: {e}")
            raise ValueError(f"Failed to read DICOM file: {e}") from e

        ts_uid = getattr(getattr(dicom_data, 'file_meta', None), 'TransferSyntaxUID', None)
        if ts_uid is not None:
            print(f"DICOM file read successfully. Transfer syntax: {ts_uid}")
        else:
            print("DICOM file read but no transfer syntax found - assuming default Implicit VR Little Endian")

        if not hasattr(dicom_data, 'PixelData'):
            print("PixelData attribute missing")
            alt_attr = next(
                (a for a in ('FloatPixelData', 'DoubleFloatPixelData') if hasattr(dicom_data, a)),
                None,
            )
            if alt_attr is None:
                raise ValueError("DICOM file does not contain any pixel data")
            print(f"Found alternate pixel data: {alt_attr}")

        print("DICOM properties:")
        for attr in ['BitsAllocated', 'BitsStored', 'HighBit', 'SamplesPerPixel', 'Rows', 'Columns']:
            print(f"  {attr}: {getattr(dicom_data, attr, 'Not specified')}")

        img = _extract_dicom_pixels(dicom_data, temp_img_file)
        if img is None:
            # Fail loudly instead of returning a placeholder that the model
            # would then "diagnose".
            raise ValueError("Failed to extract DICOM pixel data")

        # Multi-frame data carries an extra leading frame axis; keep frame 0.
        n_frames = _int_attr(dicom_data, 'NumberOfFrames', 1)
        single_frame_ndim = 2 if _int_attr(dicom_data, 'SamplesPerPixel', 1) == 1 else 3
        if n_frames > 1 and img.ndim == single_frame_ndim + 1:
            print(f"Multi-frame DICOM ({n_frames} frames) - using the first frame")
            img = img[0]

        print(f"Original image: shape={img.shape}, dtype={img.dtype}, min={np.min(img)}, max={np.max(img)}")
        img = _normalize_to_uint8(img)

        # MONOCHROME1 displays the lowest value as white; invert so that high
        # values are bright, as in MONOCHROME2 and ordinary image files.
        photometric = str(getattr(dicom_data, 'PhotometricInterpretation', '')).strip().upper()
        if photometric == 'MONOCHROME1':
            print("MONOCHROME1 photometric interpretation - inverting intensities")
            img = 255 - img

        img = _to_rgb(img)
        print(f"After color conversion: shape={img.shape}, dtype={img.dtype}")

        if img.size == 0:
            raise ValueError("Image processing resulted in invalid image")

        print(f"Final image shape before resize: {img.shape}")
        img = cv2.resize(img, TARGET_SIZE, interpolation=cv2.INTER_AREA)
        print(f"Resized image shape: {img.shape}")
        return img
    except Exception as e:
        print(f"DICOM processing failed: {e}")
        raise
    finally:
        if os.path.exists(temp_img_file):
            try:
                os.remove(temp_img_file)
            except OSError as e:
                print(f"Failed to remove temporary file {temp_img_file}: {e}")


def preprocess_image(file_bytes, content_type=None):
    """Decode an upload (DICOM or a standard image format) into model input.

    DICOM decoding is attempted when the content type mentions "dicom", is
    application/octet-stream, or the bytes carry the "DICM" marker at offset
    128. If that fails, the bytes are decoded as a regular image via OpenCV.

    Returns:
        np.ndarray of shape (540, 540, 3), dtype uint8, RGB order.

    Raises:
        ValueError: if the upload is empty or cannot be decoded either way.
            The message mentions "DICOM" when DICOM decoding was attempted.
    """
    print(f"Preprocessing image with content type: {content_type}, size: {len(file_bytes)} bytes")

    if DEBUG_DUMP_UPLOADS:
        dump_file_sample(file_bytes)

    if not file_bytes:
        raise ValueError("Uploaded file is empty")

    normalized_type = (content_type or "").lower()
    is_likely_dicom = 'dicom' in normalized_type or normalized_type == 'application/octet-stream'

    # DICOM Part 10 files have a 128-byte preamble followed by "DICM".
    if len(file_bytes) >= 132 and file_bytes[128:132] == b'DICM':
        is_likely_dicom = True
        print("DICOM signature detected in file")

    dicom_error = None
    if is_likely_dicom:
        try:
            return preprocess_dicom(file_bytes)
        except Exception as e:
            dicom_error = e
            print(f"DICOM processing error: {e}")
            print("Falling back to standard image processing")

    print("Processing as standard image format")
    try:
        img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    except cv2.error as e:
        print(f"Standard image processing error: {e}")
        img = None

    if img is None or img.size == 0:
        message = "Could not decode the upload as a standard image"
        if dicom_error is not None:
            message += f"; DICOM processing also failed: {dicom_error}"
        raise ValueError(message)

    print(f"Standard image decoded successfully: shape={img.shape}, dtype={img.dtype}")
    # OpenCV decodes to BGR; the model and PIL expect RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, TARGET_SIZE, interpolation=cv2.INTER_AREA)


@app.post("/analyze")
@limiter.limit("5/minute")
async def analyze_image(
    request: Request,  # required by slowapi's limiter decorator
    file: UploadFile = File(...)
):
    """Classify an uploaded chest X-ray and return a Grad-CAM overlay.

    Accepts ``image/*``, any content type containing "dicom", and
    ``application/octet-stream``.

    Returns:
        dict with ``predictions`` (``(class_name, score)`` pairs, highest
        first), ``heatmap_image`` (base64 PNG) and ``heatmap_format``.

    Raises:
        HTTPException: 400 for an unsupported content type, 413 above
            MAX_FILE_SIZE, 422 when the upload cannot be decoded into an
            image, 500 when inference fails.
    """
    # NOTE(review): inference runs synchronously here and blocks the event loop
    # while it runs; consider a worker thread once TensorFlow thread-safety for
    # this model is confirmed.
    content_type = file.content_type or ""
    normalized_type = content_type.lower()
    if not (normalized_type.startswith('image/')
            or 'dicom' in normalized_type
            or normalized_type == 'application/octet-stream'):
        raise HTTPException(400, "Only image or DICOM files accepted")

    too_large = f"File too large (max {MAX_FILE_SIZE//1024//1024}MB)"
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > MAX_FILE_SIZE:
        raise HTTPException(413, too_large)
    # file.size may be None, so also read at most one byte past the limit and
    # check the actual length.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(413, too_large)

    try:
        img = preprocess_image(contents, content_type)
    except ValueError as e:
        error_message = str(e)
        print(f"Preprocessing failed: {error_message}")
        if "DICOM" in error_message:
            detail = (f"DICOM processing error: {error_message}. "
                      "Please ensure your DICOM file contains valid pixel data.")
        else:
            detail = f"Failed to process image: {error_message}"
        raise HTTPException(status_code=422, detail=detail) from e
    except Exception as e:
        print(f"Analysis failed with error: {e}")
        raise HTTPException(500, f"Analysis failed: {e}") from e

    try:
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        predictions = model.predict(img_array, verbose=0)
        decoded = process_predictions(predictions)
        heatmap = generate_gradcam(img)

        # Return the overlay inline as base64 instead of writing it to disk.
        img_byte_arr = io.BytesIO()
        heatmap.save(img_byte_arr, format='PNG')
        heatmap_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
    except Exception as e:
        error_message = str(e)
        print(f"Analysis failed with error: {error_message}")
        raise HTTPException(500, f"Analysis failed: {error_message}") from e

    return {
        "predictions": decoded[0],
        "heatmap_image": heatmap_base64,
        "heatmap_format": "base64 encoded PNG"
    }


@app.get("/health")
async def health_check():
    """Liveness probe: report status, the local server time and feature flags."""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "features": {
            "dicom_support": True
        }
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)