arittrabag commited on
Commit
c5f59af
·
verified ·
1 Parent(s): 3a7c85b

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +10 -0
  2. app.py +129 -0
  3. mri_detector.h5 +3 -0
  4. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /code
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ from PIL import Image
6
+ import io
7
+ import httpx
8
+
9
+ app = FastAPI()
10
+
11
+ # Enable CORS
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Load the MRI detector model
21
+ mri_detector_model = tf.keras.models.load_model('mri_detector.h5')
22
+
23
+ # Class labels for MRI detector
24
+ MRI_CLASS_LABELS = ["Brain MRI", "Not a Brain MRI"]
25
+
26
+ # Dementia API endpoint
27
+ DEMENTIA_API_URL = "https://arittrabag-dementia-backend.hf.space/analyze"
28
+
29
+ def preprocess_image_for_mri_detection(image_bytes):
30
+ """
31
+ Preprocess image for MRI detection model
32
+ Based on the original preprocessing: resize to (224, 224) and normalize by /255.0
33
+ """
34
+ # Open image from bytes
35
+ img = Image.open(io.BytesIO(image_bytes))
36
+
37
+ # Convert to RGB if needed
38
+ if img.mode != 'RGB':
39
+ img = img.convert('RGB')
40
+
41
+ # Resize to model's expected input size (224, 224)
42
+ img = img.resize((224, 224))
43
+
44
+ # Convert to numpy array and preprocess exactly like the original
45
+ img_array = np.array(img) / 255.0 # Normalize pixel values
46
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
47
+
48
+ return img_array
49
+
50
+ async def call_dementia_api(image_bytes):
51
+ """
52
+ Call the dementia detection API with the uploaded image
53
+ """
54
+ try:
55
+ async with httpx.AsyncClient(timeout=30.0) as client:
56
+ files = {"file": ("image.jpg", image_bytes, "image/jpeg")}
57
+ response = await client.post(DEMENTIA_API_URL, files=files)
58
+ response.raise_for_status()
59
+ return response.json()
60
+ except httpx.RequestError as e:
61
+ raise HTTPException(status_code=503, detail=f"Failed to connect to dementia API: {str(e)}")
62
+ except httpx.HTTPStatusError as e:
63
+ raise HTTPException(status_code=e.response.status_code, detail=f"Dementia API error: {e.response.text}")
64
+
65
+ @app.get("/")
66
+ async def root():
67
+ return {"message": "MRI Detector API - Upload an image to check if it's an MRI scan"}
68
+
69
+ @app.post("/detect")
70
+ async def detect_and_analyze(file: UploadFile = File(...)):
71
+ """
72
+ Main endpoint that:
73
+ 1. Checks if uploaded image is an MRI
74
+ 2. If MRI, calls dementia detection API
75
+ 3. Returns combined results
76
+ """
77
+ # Validate file type
78
+ if not file.content_type or not file.content_type.startswith('image/'):
79
+ raise HTTPException(status_code=400, detail="File must be an image")
80
+
81
+ # Read image file
82
+ contents = await file.read()
83
+
84
+ try:
85
+ # Step 1: Check if image is MRI
86
+ img_array = preprocess_image_for_mri_detection(contents)
87
+ mri_predictions = mri_detector_model.predict(img_array)
88
+
89
+ # Get MRI detection results
90
+ mri_confidences = mri_predictions[0].tolist()
91
+ predicted_class_idx = np.argmax(mri_confidences)
92
+ predicted_class = MRI_CLASS_LABELS[predicted_class_idx]
93
+ mri_confidence = float(mri_confidences[predicted_class_idx])
94
+
95
+ # Create MRI confidence dictionary
96
+ mri_confidence_dict = {label: float(conf) for label, conf in zip(MRI_CLASS_LABELS, mri_confidences)}
97
+
98
+ response = {
99
+ "isMRI": predicted_class == "Brain MRI",
100
+ "mriConfidence": mri_confidence,
101
+ "mriClassification": {
102
+ "predictedClass": predicted_class,
103
+ "confidences": mri_confidence_dict
104
+ }
105
+ }
106
+
107
+ # Step 2: If it's an MRI, call dementia detection API
108
+ if predicted_class == "Brain MRI":
109
+ try:
110
+ dementia_results = await call_dementia_api(contents)
111
+ response["dementiaAnalysis"] = dementia_results
112
+ response["status"] = "analysis_complete"
113
+ response["message"] = "Image identified as Brain MRI scan. Dementia analysis completed."
114
+ except Exception as e:
115
+ response["status"] = "mri_detected_but_analysis_failed"
116
+ response["message"] = f"Image identified as Brain MRI scan, but dementia analysis failed: {str(e)}"
117
+ response["error"] = str(e)
118
+ else:
119
+ response["status"] = "not_mri"
120
+ response["message"] = f"Image identified as {predicted_class} with {mri_confidence*100:.1f}% confidence. Dementia analysis not performed as this is not a Brain MRI scan."
121
+
122
+ return response
123
+
124
+ except Exception as e:
125
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
126
+
127
+ @app.get("/health")
128
+ async def health_check():
129
+ return {"status": "healthy", "message": "MRI Detector API is running"}
mri_detector.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03055c27dc993ff330e8441fd042c12a6916ad74f66d9ba7fed4bc0585bce276
3
+ size 134077984
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ tensorflow
4
+ numpy
5
+ pillow
6
+ httpx
7
+ python-multipart