Naman2302 commited on
Commit
ebf9ab4
·
verified ·
1 Parent(s): 9d19f07

added the valid xray image detection code

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -4,7 +4,6 @@ import cv2
4
  import tempfile
5
  import os
6
  import sys
7
-
8
  # Add project root to Python path
9
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
 
@@ -13,7 +12,7 @@ from src.predict_fracture import FracturePredictor
13
 
14
  # Get current script location
15
  current_dir = os.path.dirname(os.path.abspath(__file__))
16
- project_root = os.path.dirname(current_dir) # Go up from app/ to project root
17
 
18
  # CORRECTED MODEL PATHS
19
  MODEL_PATH = 'models/fracture_detection_model.joblib'
@@ -29,12 +28,51 @@ if os.path.exists(MODEL_PATH) and os.path.exists(ENCODER_PATH):
29
  predictor = FracturePredictor(model_path=MODEL_PATH, encoder_path=ENCODER_PATH)
30
  else:
31
  print("ERROR: Model files not found. Please run training first.")
 
 
 
32
  exit(1)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def predict_fracture(img):
35
  """Process uploaded image and return prediction results"""
36
  try:
37
- # Handle different input types
 
 
 
 
38
  if isinstance(img, np.ndarray):
39
  # Convert to BGR format for OpenCV
40
  if img.shape[2] == 4: # RGBA image
@@ -50,10 +88,10 @@ def predict_fracture(img):
50
  # Already a file path
51
  tmp_path = img
52
 
53
- # Get prediction
54
  label, confidence, vis_path = predictor.predict(tmp_path)
55
 
56
- # Read visualization image
57
  vis_img = cv2.imread(vis_path)
58
  if vis_img is not None:
59
  vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
@@ -92,4 +130,3 @@ if __name__ == "__main__":
92
  server_port=7860,
93
  share=True # Add this line to enable public access
94
  )
95
-
 
4
  import tempfile
5
  import os
6
  import sys
 
7
  # Add project root to Python path
8
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
 
 
12
 
13
  # Get current script location
14
  current_dir = os.path.dirname(os.path.abspath(__file__))
15
+ project_root = os.path.dirname(current_dir) # Go up from app/ to project root
16
 
17
  # CORRECTED MODEL PATHS
18
  MODEL_PATH = 'models/fracture_detection_model.joblib'
 
28
  predictor = FracturePredictor(model_path=MODEL_PATH, encoder_path=ENCODER_PATH)
29
  else:
30
  print("ERROR: Model files not found. Please run training first.")
31
+ # Provide detailed troubleshooting help
32
+ print(f"Current working directory: {os.getcwd()}")
33
+ print(f"Files in models directory: {os.listdir(os.path.join(project_root, 'models'))}")
34
  exit(1)
35
 
36
+ def is_xray_image(img):
37
+ """Validate if image is an X-ray using intensity distribution"""
38
+ try:
39
+ if isinstance(img, np.ndarray):
40
+ # Convert to grayscale if needed
41
+ if len(img.shape) == 3 and img.shape[2] == 3:
42
+ img_gray = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140])
43
+ elif len(img.shape) == 3 and img.shape[2] == 4:
44
+ img_gray = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140])
45
+ else:
46
+ img_gray = img if len(img.shape) == 2 else img[:, :, 0]
47
+ else:
48
+ # Handle file path
49
+ img_array = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
50
+ if img_array is None:
51
+ return False
52
+ img_gray = np.array(img_array)
53
+
54
+ # Calculate statistics
55
+ mean_intensity = np.mean(img_gray)
56
+ std_intensity = np.std(img_gray)
57
+
58
+ # X-ray characteristics:
59
+ # - Moderate brightness (not too dark/light)
60
+ # - Reasonable contrast
61
+ is_valid = (20 <= mean_intensity <= 230) and (std_intensity >= 10)
62
+ return is_valid
63
+
64
+ except Exception as e:
65
+ print(f"Validation error: {str(e)}")
66
+ return False
67
+
68
  def predict_fracture(img):
69
  """Process uploaded image and return prediction results"""
70
  try:
71
+ # Step 1: Validate if it's an X-ray
72
+ if not is_xray_image(img):
73
+ return "⚠️ Not an X-ray image", "Upload a valid X-ray", None
74
+
75
+ # Step 2: Process the image
76
  if isinstance(img, np.ndarray):
77
  # Convert to BGR format for OpenCV
78
  if img.shape[2] == 4: # RGBA image
 
88
  # Already a file path
89
  tmp_path = img
90
 
91
+ # Step 3: Get prediction
92
  label, confidence, vis_path = predictor.predict(tmp_path)
93
 
94
+ # Step 4: Read visualization
95
  vis_img = cv2.imread(vis_path)
96
  if vis_img is not None:
97
  vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
 
130
  server_port=7860,
131
  share=True # Add this line to enable public access
132
  )