arittrabag commited on
Commit
678235f
·
verified ·
1 Parent(s): c5f59af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -128
app.py CHANGED
@@ -1,129 +1,129 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import tensorflow as tf
4
- import numpy as np
5
- from PIL import Image
6
- import io
7
- import httpx
8
-
9
- app = FastAPI()
10
-
11
- # Enable CORS
12
- app.add_middleware(
13
- CORSMiddleware,
14
- allow_origins=["*"],
15
- allow_credentials=True,
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
- # Load the MRI detector model
21
- mri_detector_model = tf.keras.models.load_model('mri_detector.h5')
22
-
23
- # Class labels for MRI detector
24
- MRI_CLASS_LABELS = ["Brain MRI", "Not a Brain MRI"]
25
-
26
- # Dementia API endpoint
27
- DEMENTIA_API_URL = "https://arittrabag-dementia-backend.hf.space/analyze"
28
-
29
- def preprocess_image_for_mri_detection(image_bytes):
30
- """
31
- Preprocess image for MRI detection model
32
- Based on the original preprocessing: resize to (224, 224) and normalize by /255.0
33
- """
34
- # Open image from bytes
35
- img = Image.open(io.BytesIO(image_bytes))
36
-
37
- # Convert to RGB if needed
38
- if img.mode != 'RGB':
39
- img = img.convert('RGB')
40
-
41
- # Resize to model's expected input size (224, 224)
42
- img = img.resize((224, 224))
43
-
44
- # Convert to numpy array and preprocess exactly like the original
45
- img_array = np.array(img) / 255.0 # Normalize pixel values
46
- img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
47
-
48
- return img_array
49
-
50
- async def call_dementia_api(image_bytes):
51
- """
52
- Call the dementia detection API with the uploaded image
53
- """
54
- try:
55
- async with httpx.AsyncClient(timeout=30.0) as client:
56
- files = {"file": ("image.jpg", image_bytes, "image/jpeg")}
57
- response = await client.post(DEMENTIA_API_URL, files=files)
58
- response.raise_for_status()
59
- return response.json()
60
- except httpx.RequestError as e:
61
- raise HTTPException(status_code=503, detail=f"Failed to connect to dementia API: {str(e)}")
62
- except httpx.HTTPStatusError as e:
63
- raise HTTPException(status_code=e.response.status_code, detail=f"Dementia API error: {e.response.text}")
64
-
65
- @app.get("/")
66
- async def root():
67
- return {"message": "MRI Detector API - Upload an image to check if it's an MRI scan"}
68
-
69
- @app.post("/detect")
70
- async def detect_and_analyze(file: UploadFile = File(...)):
71
- """
72
- Main endpoint that:
73
- 1. Checks if uploaded image is an MRI
74
- 2. If MRI, calls dementia detection API
75
- 3. Returns combined results
76
- """
77
- # Validate file type
78
- if not file.content_type or not file.content_type.startswith('image/'):
79
- raise HTTPException(status_code=400, detail="File must be an image")
80
-
81
- # Read image file
82
- contents = await file.read()
83
-
84
- try:
85
- # Step 1: Check if image is MRI
86
- img_array = preprocess_image_for_mri_detection(contents)
87
- mri_predictions = mri_detector_model.predict(img_array)
88
-
89
- # Get MRI detection results
90
- mri_confidences = mri_predictions[0].tolist()
91
- predicted_class_idx = np.argmax(mri_confidences)
92
- predicted_class = MRI_CLASS_LABELS[predicted_class_idx]
93
- mri_confidence = float(mri_confidences[predicted_class_idx])
94
-
95
- # Create MRI confidence dictionary
96
- mri_confidence_dict = {label: float(conf) for label, conf in zip(MRI_CLASS_LABELS, mri_confidences)}
97
-
98
- response = {
99
- "isMRI": predicted_class == "Brain MRI",
100
- "mriConfidence": mri_confidence,
101
- "mriClassification": {
102
- "predictedClass": predicted_class,
103
- "confidences": mri_confidence_dict
104
- }
105
- }
106
-
107
- # Step 2: If it's an MRI, call dementia detection API
108
- if predicted_class == "Brain MRI":
109
- try:
110
- dementia_results = await call_dementia_api(contents)
111
- response["dementiaAnalysis"] = dementia_results
112
- response["status"] = "analysis_complete"
113
- response["message"] = "Image identified as Brain MRI scan. Dementia analysis completed."
114
- except Exception as e:
115
- response["status"] = "mri_detected_but_analysis_failed"
116
- response["message"] = f"Image identified as Brain MRI scan, but dementia analysis failed: {str(e)}"
117
- response["error"] = str(e)
118
- else:
119
- response["status"] = "not_mri"
120
- response["message"] = f"Image identified as {predicted_class} with {mri_confidence*100:.1f}% confidence. Dementia analysis not performed as this is not a Brain MRI scan."
121
-
122
- return response
123
-
124
- except Exception as e:
125
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
126
-
127
- @app.get("/health")
128
- async def health_check():
129
  return {"status": "healthy", "message": "MRI Detector API is running"}
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ from PIL import Image
6
+ import io
7
+ import httpx
8
+
9
+ app = FastAPI()
10
+
11
+ # Enable CORS
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Load the MRI detector model
21
+ mri_detector_model = tf.keras.models.load_model('mri_detector.h5')
22
+
23
+ # Class labels for MRI detector
24
+ MRI_CLASS_LABELS = ["Brain MRI", "Not a Brain MRI"]
25
+
26
+ # Dementia API endpoint
27
+ DEMENTIA_API_URL = "https://arittrabag-alzheimers-h4b.hf.space/analyze"
28
+
29
+ def preprocess_image_for_mri_detection(image_bytes):
30
+ """
31
+ Preprocess image for MRI detection model
32
+ Based on the original preprocessing: resize to (224, 224) and normalize by /255.0
33
+ """
34
+ # Open image from bytes
35
+ img = Image.open(io.BytesIO(image_bytes))
36
+
37
+ # Convert to RGB if needed
38
+ if img.mode != 'RGB':
39
+ img = img.convert('RGB')
40
+
41
+ # Resize to model's expected input size (224, 224)
42
+ img = img.resize((224, 224))
43
+
44
+ # Convert to numpy array and preprocess exactly like the original
45
+ img_array = np.array(img) / 255.0 # Normalize pixel values
46
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
47
+
48
+ return img_array
49
+
50
+ async def call_dementia_api(image_bytes):
51
+ """
52
+ Call the dementia detection API with the uploaded image
53
+ """
54
+ try:
55
+ async with httpx.AsyncClient(timeout=30.0) as client:
56
+ files = {"file": ("image.jpg", image_bytes, "image/jpeg")}
57
+ response = await client.post(DEMENTIA_API_URL, files=files)
58
+ response.raise_for_status()
59
+ return response.json()
60
+ except httpx.RequestError as e:
61
+ raise HTTPException(status_code=503, detail=f"Failed to connect to dementia API: {str(e)}")
62
+ except httpx.HTTPStatusError as e:
63
+ raise HTTPException(status_code=e.response.status_code, detail=f"Dementia API error: {e.response.text}")
64
+
65
+ @app.get("/")
66
+ async def root():
67
+ return {"message": "MRI Detector API - Upload an image to check if it's an MRI scan"}
68
+
69
+ @app.post("/detect")
70
+ async def detect_and_analyze(file: UploadFile = File(...)):
71
+ """
72
+ Main endpoint that:
73
+ 1. Checks if uploaded image is an MRI
74
+ 2. If MRI, calls dementia detection API
75
+ 3. Returns combined results
76
+ """
77
+ # Validate file type
78
+ if not file.content_type or not file.content_type.startswith('image/'):
79
+ raise HTTPException(status_code=400, detail="File must be an image")
80
+
81
+ # Read image file
82
+ contents = await file.read()
83
+
84
+ try:
85
+ # Step 1: Check if image is MRI
86
+ img_array = preprocess_image_for_mri_detection(contents)
87
+ mri_predictions = mri_detector_model.predict(img_array)
88
+
89
+ # Get MRI detection results
90
+ mri_confidences = mri_predictions[0].tolist()
91
+ predicted_class_idx = np.argmax(mri_confidences)
92
+ predicted_class = MRI_CLASS_LABELS[predicted_class_idx]
93
+ mri_confidence = float(mri_confidences[predicted_class_idx])
94
+
95
+ # Create MRI confidence dictionary
96
+ mri_confidence_dict = {label: float(conf) for label, conf in zip(MRI_CLASS_LABELS, mri_confidences)}
97
+
98
+ response = {
99
+ "isMRI": predicted_class == "Brain MRI",
100
+ "mriConfidence": mri_confidence,
101
+ "mriClassification": {
102
+ "predictedClass": predicted_class,
103
+ "confidences": mri_confidence_dict
104
+ }
105
+ }
106
+
107
+ # Step 2: If it's an MRI, call dementia detection API
108
+ if predicted_class == "Brain MRI":
109
+ try:
110
+ dementia_results = await call_dementia_api(contents)
111
+ response["dementiaAnalysis"] = dementia_results
112
+ response["status"] = "analysis_complete"
113
+ response["message"] = "Image identified as Brain MRI scan. Dementia analysis completed."
114
+ except Exception as e:
115
+ response["status"] = "mri_detected_but_analysis_failed"
116
+ response["message"] = f"Image identified as Brain MRI scan, but dementia analysis failed: {str(e)}"
117
+ response["error"] = str(e)
118
+ else:
119
+ response["status"] = "not_mri"
120
+ response["message"] = f"Image identified as {predicted_class} with {mri_confidence*100:.1f}% confidence. Dementia analysis not performed as this is not a Brain MRI scan."
121
+
122
+ return response
123
+
124
+ except Exception as e:
125
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
126
+
127
+ @app.get("/health")
128
+ async def health_check():
129
  return {"status": "healthy", "message": "MRI Detector API is running"}