# CheXNetDeep / app.py
# Last commit: 3aab522 by thehammadishaq ("updated Dicom")
from fastapi import FastAPI, UploadFile, File, HTTPException, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import io
import uuid
from datetime import datetime, timedelta
import base64
import pydicom
import os
# Configuration
MAX_FILE_SIZE = 10 * 2**20  # largest accepted upload, in bytes (10 MiB)
PORT = 7860  # port used when this module is run directly
app = FastAPI(
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
    version="5.0.0",
)

# Rate limiting: slowapi looks the limiter up on app.state; requests are keyed
# by get_remote_address (the client's address).
# NOTE(review): slowapi's quickstart also registers
# `_rate_limit_exceeded_handler` for RateLimitExceeded; it is not registered
# here — confirm the default 429 response body is acceptable.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True let any
# site issue credentialed requests; nothing in this service appears to use
# cookies/auth, so allow_credentials=False may be sufficient — confirm with
# the frontend before changing.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Model configuration
# Convolutional layer whose activations are weighted to build the Grad-CAM map.
layer_name = 'conv5_block16_concat'

# Label for each model output unit: process_predictions pairs class_names[i]
# with score i.
# NOTE(review): build_model() has 14 output units, so the trailing
# 'No Finding' (index 14) is never produced by that model — confirm whether
# the fallback model has 15 outputs.
class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural_Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
    'No Finding',
]
def build_model():
    """Assemble the classifier architecture used by this service.

    A DenseNet-121 feature extractor (randomly initialised via ``weights=None``,
    no classification top, spatial input size left open) is followed by global
    average pooling and a 14-unit sigmoid layer, i.e. one independent score per
    output unit.

    Returns:
        tf.keras.Model mapping a batch of (H, W, 3) images to 14 scores.
    """
    backbone = DenseNet121(include_top=False, weights=None, input_shape=(None, None, 3))
    pooled = GlobalAveragePooling2D()(backbone.output)
    scores = Dense(14, activation='sigmoid')(pooled)
    return Model(inputs=backbone.input, outputs=scores)
def load_model_with_fallback(weights_path='pretrained_model.h5', fallback_path='Densenet.h5'):
    """Load the classifier, trying two on-disk formats in order.

    1. Rebuild the architecture with ``build_model()`` and load weights from
       ``weights_path``.
    2. Otherwise load a complete saved Keras model from ``fallback_path``
       (``compile=False``: the model is only used for inference).

    Args:
        weights_path: HDF5 weights file matching ``build_model()``.
        fallback_path: full saved-model file used if step 1 fails.

    Returns:
        The loaded tf.keras model.

    Raises:
        RuntimeError: if both strategies fail. The message names both files
            and both underlying errors; ``__cause__`` is the fallback error.
    """
    try:
        net = build_model()
        net.load_weights(weights_path)
        return net
    except Exception as e:
        # Keep the object: the name `e` is unbound when this except block ends.
        primary_error = e
        print(f"Primary loading failed: {e}")

    try:
        return load_model(fallback_path, compile=False)
    except Exception as e:
        print(f"Fallback loading failed: {e}")
        raise RuntimeError(
            "All model loading strategies failed "
            f"(weights {weights_path!r}: {primary_error}; "
            f"full model {fallback_path!r}: {e})"
        ) from e
# Load the classifier once, at import time; any failure is logged and re-raised,
# so the module (and the server) does not start without a model.
try:
    model = load_model_with_fallback()
except Exception as load_error:
    print(f"❌ Model loading failed: {load_error}")
    raise
else:
    print("✅ Model loaded successfully!")
def generate_gradcam(img):
    """Return ``img`` blended with a Grad-CAM heatmap for the top-scoring class.

    The image is converted with ``img_to_array``, given a batch axis and passed
    through DenseNet's ``preprocess_input``. It is then run through a sub-model
    that exposes both the ``layer_name`` activations and the class scores. The
    gradient of the highest class score with respect to those activations
    (guided: only positive activations and positive gradients are kept) gives
    per-channel weights. The weighted channel sum is the class activation map.

    Args:
        img: uint8 image array, presumably (H, W, 3) RGB as produced by
            ``preprocess_image`` — TODO confirm for any other caller.

    Returns:
        PIL.Image.Image: 50/50 blend of ``img`` and the jet-colormapped
        heatmap, at ``img``'s size.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    # Sub-model returning the target conv layer's activations alongside the scores.
    # NOTE(review): rebuilt on every call; it could be created once after `model` loads.
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = int(tf.argmax(predictions[0]))
        # Grad-CAM explains a single class: differentiate only the top class's
        # score rather than the sum of all class scores.
        class_score = predictions[:, class_idx]

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]

    # Guided gradients: keep only positions where activation and gradient are both positive.
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)

    heatmap = np.maximum(cam, 0)
    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    # An all-zero map stays all zeros instead of becoming NaN (0 / 0).

    heatmap_img = plt.cm.jet(heatmap)[..., :3]
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_img * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)
def process_predictions(predictions):
    """Label each row of scores and order it from highest to lowest score.

    Args:
        predictions: iterable of 1-D score arrays (e.g. the batch returned by
            ``model.predict``); score ``i`` is labelled ``class_names[i]``.

    Returns:
        One list per row, each holding up to ``len(class_names)``
        ``(label, float(score))`` tuples in descending score order.
    """
    return [
        [
            (class_names[idx], float(row[idx]))
            for idx in np.argsort(row)[::-1][:len(class_names)]
        ]
        for row in predictions
    ]
def dump_file_sample(file_bytes, filename="debug_file_sample.bin"):
    """Write up to the first 512 bytes of ``file_bytes`` to ``filename`` for debugging.

    Also prints the first 16 bytes as hex. This is best-effort: any error is
    printed and swallowed, and the function always returns None.

    NOTE(review): with the default ``filename`` every call overwrites the same
    file in the working directory, and the bytes come straight from user
    uploads (DICOM headers may carry patient data) — confirm this is acceptable
    outside development.
    """
    try:
        count = min(512, len(file_bytes))
        head = file_bytes[:count]
        with open(filename, "wb") as sink:
            sink.write(head)
        print(f"Saved {count} bytes sample to {filename}")

        preview = " ".join(format(byte, "02x") for byte in file_bytes[:16])
        print(f"First 16 bytes: {preview}")
    except Exception as exc:
        print(f"Failed to save debug sample: {exc}")
def preprocess_dicom(file_bytes):
    """Decode DICOM bytes into a 540x540 RGB uint8 array for the model.

    Steps: parse the bytes in memory, extract the pixel array (first frame
    when multi-frame), min-max scale it to uint8, invert MONOCHROME1 data,
    expand to 3 channels, and resize to 540x540.

    Unlike before, no temporary files are written. A synthetic placeholder
    image is never returned: classifying a placeholder would report
    predictions for an image the user never sent.

    Args:
        file_bytes: raw bytes of the uploaded file.

    Returns:
        numpy.ndarray of shape (540, 540, 3), dtype uint8.

    Raises:
        ValueError: if the bytes cannot be parsed, contain no pixel data, or
            decode to an unusable array. Messages contain "DICOM".
    """
    print(f"Processing DICOM file of size {len(file_bytes)} bytes")
    try:
        dicom_data = _read_dicom(file_bytes)
        img = _extract_dicom_pixels(dicom_data)
        print(f"Original image: shape={img.shape}, dtype={img.dtype}, min={np.min(img)}, max={np.max(img)}")
        img = _dicom_pixels_to_uint8(img, dicom_data)
        img = _dicom_pixels_to_rgb(img)
        print(f"After color conversion: shape={img.shape}, dtype={img.dtype}")
        img = cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
        print(f"Resized image shape: {img.shape}")
        return img
    except Exception as e:
        print(f"DICOM processing failed: {e}")
        raise


def _read_dicom(file_bytes):
    """Parse DICOM bytes in memory and log the transfer syntax; raise ValueError on failure."""
    try:
        # force=True also accepts data lacking the 128-byte preamble / "DICM" marker.
        # Reading from BytesIO avoids writing the upload to a temporary file.
        dicom_data = pydicom.dcmread(io.BytesIO(file_bytes), force=True, defer_size=None)
    except Exception as e:
        print(f"Error reading DICOM file: {e}")
        raise ValueError(f"Failed to read DICOM file: {e}") from e

    ts_uid = getattr(getattr(dicom_data, 'file_meta', None), 'TransferSyntaxUID', None)
    if ts_uid is not None:
        print(f"DICOM file read successfully. Transfer syntax: {ts_uid}")
    else:
        print("DICOM file read but no transfer syntax found - assuming default Implicit VR Little Endian")
    return dicom_data


def _extract_dicom_pixels(dicom_data):
    """Return the decoded pixel array (first frame if multi-frame); raise ValueError if unusable."""
    pixel_attrs = ('PixelData', 'FloatPixelData', 'DoubleFloatPixelData')
    present = [attr for attr in pixel_attrs if hasattr(dicom_data, attr)]
    if not present:
        print("PixelData attribute missing")
        raise ValueError("DICOM file does not contain any pixel data")
    if present[0] != 'PixelData':
        print(f"Found alternate pixel data: {present[0]}")

    print("DICOM properties:")
    for attr in ('BitsAllocated', 'BitsStored', 'HighBit', 'SamplesPerPixel', 'Rows', 'Columns'):
        print(f" {attr}: {getattr(dicom_data, attr, 'Not specified')}")

    # pixel_array decodes through pydicom's pixel-data handlers; compressed
    # transfer syntaxes may need an optional handler package installed.
    # Re-saving with save_as() and reading the file back with cv2/PIL cannot
    # help: save_as() always writes DICOM, which neither library decodes.
    try:
        img = dicom_data.pixel_array
    except Exception as e:
        print(f"Method 1 (direct pixel_array) failed: {e}")
        raise ValueError(f"Failed to decode DICOM pixel data: {e}") from e
    if img.size == 0:
        raise ValueError("DICOM pixel data decoded to an empty array")
    print(f"Successfully extracted pixel data via pixel_array: shape={img.shape}, dtype={img.dtype}")

    # Multi-frame data has frames on axis 0; treating them as channels would
    # average unrelated frames together, so keep only the first frame.
    frames = int(getattr(dicom_data, 'NumberOfFrames', 1) or 1)
    if frames > 1 and img.ndim >= 3 and img.shape[0] == frames:
        print(f"Multi-frame DICOM ({frames} frames); using the first frame")
        img = img[0]
    return img


def _dicom_pixels_to_uint8(img, dicom_data):
    """Min-max scale non-uint8 pixels to [0, 255] and invert MONOCHROME1 data.

    A non-uint8 image with a single distinct value becomes uniform mid-gray (128).
    """
    if img.dtype != np.uint8:
        img_min = float(np.min(img))
        img_max = float(np.max(img))
        if img_max > img_min:
            scaled = 255.0 * (img.astype(np.float32) - img_min) / (img_max - img_min)
            img = scaled.astype(np.uint8)
            print(f"Normalized to 8-bit: new range=[{np.min(img)}, {np.max(img)}]")
        else:
            img = np.full(img.shape, 128, dtype=np.uint8)
            print("Image has uniform pixel values, using mid-gray")

    # MONOCHROME1: the minimum value is meant to be displayed as white. Invert
    # so that, as with MONOCHROME2 and ordinary PNG/JPEG X-rays, higher values
    # render brighter.
    photometric = str(getattr(dicom_data, 'PhotometricInterpretation', '')).strip().upper()
    if photometric == 'MONOCHROME1':
        print("MONOCHROME1 photometric interpretation: inverting intensities")
        img = 255 - img
    return img


def _dicom_pixels_to_rgb(img):
    """Return a contiguous (H, W, 3) uint8 array; raise ValueError for non 2-D/3-D input."""
    if img.ndim == 2:
        print("Converting grayscale to RGB using manual conversion")
        return np.stack([img, img, img], axis=-1)
    if img.ndim != 3:
        print(f"Invalid image dimensions: {img.shape}")
        raise ValueError(f"DICOM pixel data has unsupported dimensions: {img.shape}")

    channels = img.shape[2]
    if channels == 1:
        print("Converting single-channel 3D array to RGB")
        return np.repeat(img, 3, axis=2)
    if channels == 3:
        print("Image already has 3 channels, ensuring RGB color space")
        return img
    if channels == 4:
        print("Converting RGBA to RGB by removing alpha channel")
        # Copy: the channel slice is a non-contiguous view.
        return np.ascontiguousarray(img[:, :, :3])

    print(f"Unusual channel count ({channels}), converting to grayscale then RGB")
    gray = np.mean(img, axis=2).astype(np.uint8)
    return np.stack([gray, gray, gray], axis=-1)
def preprocess_image(file_bytes, content_type=None):
    """Decode an uploaded DICOM or standard image into a 540x540 RGB uint8 array.

    DICOM decoding is tried when either of these holds:
    - the bytes carry the DICOM Part 10 signature (b"DICM" at offset 128);
    - the content type contains "dicom" or is application/octet-stream.

    If a DICOM attempt that was only *hinted* by the content type fails, the
    bytes are decoded with OpenCV instead. Previously an undecodable upload
    produced a synthetic placeholder image, which the caller then classified
    and reported as a result. Now a ValueError is raised.

    Args:
        file_bytes: raw uploaded bytes.
        content_type: client-reported MIME type, or None; used only as a hint.

    Returns:
        numpy.ndarray of shape (540, 540, 3), dtype uint8, RGB channel order.

    Raises:
        ValueError: if the bytes cannot be decoded. For a file with the DICOM
            signature, the DICOM error is re-raised unchanged. Otherwise, if a
            DICOM attempt was made, the message includes that error.
    """
    print(f"Preprocessing image with content type: {content_type}, size: {len(file_bytes)} bytes")
    # NOTE(review): writes the head of every upload to a fixed file in the
    # working directory (shared by concurrent requests; DICOM headers may hold
    # patient identifiers). Consider disabling this outside development.
    dump_file_sample(file_bytes)

    ctype = (content_type or "").lower()
    # 'application/dicom' is already covered by the substring test.
    hinted_dicom = 'dicom' in ctype or ctype == 'application/octet-stream'
    # DICOM Part 10 files: 128-byte preamble followed by the b"DICM" marker.
    has_dicom_signature = file_bytes[128:132] == b'DICM'
    if has_dicom_signature:
        print("DICOM signature detected in file")

    dicom_error = None
    if hinted_dicom or has_dicom_signature:
        try:
            return preprocess_dicom(file_bytes)
        except Exception as e:
            print(f"DICOM processing error: {e}")
            if has_dicom_signature:
                # Unambiguously DICOM: a generic image decoder cannot do better.
                raise
            dicom_error = e
            print("Falling back to standard image processing")

    print("Processing as standard image format")
    if not file_bytes:
        # cv2.imdecode raises on an empty buffer instead of returning None.
        raise ValueError("Uploaded file is empty")
    img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None or img.size == 0:
        print("Standard image decoding failed")
        message = "Unable to decode the uploaded file as an image"
        if dicom_error is not None:
            message += f" (DICOM processing error: {dicom_error})"
        raise ValueError(message)

    print(f"Standard image decoded successfully: shape={img.shape}, dtype={img.dtype}")
    # OpenCV decodes to BGR; downstream code (e.g. PIL's Image.fromarray) treats arrays as RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
@app.post("/analyze")
@limiter.limit("5/minute")
async def analyze_image(
    request: Request,  # unused here, but slowapi's limit decorator needs it
    file: UploadFile = File(...)
):
    """Classify an uploaded chest X-ray and return scores plus a Grad-CAM overlay.

    Accepted content types (case-insensitive): ``image/*``, anything
    containing ``dicom``, or ``application/octet-stream``. Rate limited by
    ``limiter`` to 5 requests per minute.

    Returns:
        dict with ``predictions`` (list of ``(label, score)`` in descending
        score order), ``heatmap_image`` (base64 PNG) and ``heatmap_format``.

    Raises:
        HTTPException: 400 for unsupported content types, 413 for files over
            MAX_FILE_SIZE, 422 for DICOM errors, 500 for other failures.
    """
    # content_type may be None when the client omits it.
    content_type = (file.content_type or "").lower()
    if not (content_type.startswith('image/')
            or 'dicom' in content_type
            or content_type == 'application/octet-stream'):
        raise HTTPException(400, "Only image or DICOM files accepted")

    too_large = f"File too large (max {MAX_FILE_SIZE//1024//1024}MB)"
    # UploadFile.size can be None (or missing on older Starlette); the actual
    # length is re-checked after reading.
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > MAX_FILE_SIZE:
        raise HTTPException(413, too_large)

    # NOTE(review): prediction and Grad-CAM are CPU-heavy and run directly in
    # this async handler, blocking the event loop. Consider run_in_threadpool,
    # after confirming the Keras model is safe to call from several threads.
    try:
        contents = await file.read()
        if len(contents) > MAX_FILE_SIZE:
            raise HTTPException(413, too_large)

        img = preprocess_image(contents, file.content_type)
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)
        predictions = model.predict(img_array)
        decoded = process_predictions(predictions)
        heatmap = generate_gradcam(img)

        # Return the overlay inline as base64 instead of writing it to disk.
        img_byte_arr = io.BytesIO()
        heatmap.save(img_byte_arr, format='PNG')
        heatmap_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        return {
            "predictions": decoded[0],
            "heatmap_image": heatmap_base64,
            "heatmap_format": "base64 encoded PNG"
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 413) instead of turning them into 500s.
        raise
    except Exception as e:
        error_message = str(e)
        print(f"Analysis failed with error: {error_message}")
        # Map known failure messages to more helpful responses.
        if "empty()" in error_message and "cvtColor" in error_message:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to process image: The image data is empty or corrupt. Please check your DICOM file format. Original error: {error_message}"
            ) from e
        elif "DICOM" in error_message:
            raise HTTPException(
                status_code=422,
                detail=f"DICOM processing error: {error_message}. Please ensure your DICOM file contains valid pixel data."
            ) from e
        else:
            raise HTTPException(500, f"Analysis failed: {error_message}") from e
@app.get("/health")
async def health_check():
    """Liveness endpoint: fixed status, current local server time (ISO 8601), feature flags."""
    payload = {"status": "healthy"}
    payload["timestamp"] = datetime.now().isoformat()
    payload["features"] = {"dicom_support": True}
    return payload
if __name__ == "__main__":
    # Serve the app directly when the module is executed as a script.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=PORT)