Naman2302 commited on
Commit
9074b47
·
verified ·
1 Parent(s): d037fba

app.py created

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import tempfile
5
+ import os
6
+ import sys
7
+
8
+ # Add project root to Python path
9
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+
11
+ # Import predictor class
12
+ from src.predict_fracture import FracturePredictor
13
+
14
+ # Get current script location
15
+ current_dir = os.path.dirname(os.path.abspath(__file__))
16
+ project_root = os.path.dirname(current_dir) # Go up from app/ to project root
17
+
18
+ # CORRECTED MODEL PATHS
19
+ MODEL_PATH = 'models/fracture_detection_model.joblib'
20
+ ENCODER_PATH = 'models/label_encoder.joblib'
21
+ # Debugging output
22
+ print(f"Project root: {project_root}")
23
+ print(f"Model path: {MODEL_PATH}")
24
+ print(f"Model exists: {os.path.exists(MODEL_PATH)}")
25
+ print(f"Encoder exists: {os.path.exists(ENCODER_PATH)}")
26
+
27
+ # Initialize predictor only if files exist
28
+ if os.path.exists(MODEL_PATH) and os.path.exists(ENCODER_PATH):
29
+ predictor = FracturePredictor(model_path=MODEL_PATH, encoder_path=ENCODER_PATH)
30
+ else:
31
+ print("ERROR: Model files not found. Please run training first.")
32
+ exit(1)
33
+
34
+ def predict_fracture(img):
35
+ """Process uploaded image and return prediction results"""
36
+ try:
37
+ # Handle different input types
38
+ if isinstance(img, np.ndarray):
39
+ # Convert to BGR format for OpenCV
40
+ if img.shape[2] == 4: # RGBA image
41
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
42
+ else: # RGB image
43
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
44
+
45
+ # Save to temp file
46
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
47
+ tmp_path = tmp.name
48
+ cv2.imwrite(tmp_path, img_bgr)
49
+ else:
50
+ # Already a file path
51
+ tmp_path = img
52
+
53
+ # Get prediction
54
+ label, confidence, vis_path = predictor.predict(tmp_path)
55
+
56
+ # Read visualization image
57
+ vis_img = cv2.imread(vis_path)
58
+ if vis_img is not None:
59
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
60
+
61
+ # Clean up temporary file
62
+ if isinstance(img, np.ndarray) and os.path.exists(tmp_path):
63
+ os.unlink(tmp_path)
64
+
65
+ return label, f"{confidence:.4f}", vis_img
66
+
67
+ except Exception as e:
68
+ print(f"Prediction error: {str(e)}")
69
+ return "Error", "N/A", None
70
+
71
+ # Create Gradio interface
72
+ iface = gr.Interface(
73
+ fn=predict_fracture,
74
+ inputs=gr.Image(label="Upload X-Ray Image"),
75
+ outputs=[
76
+ gr.Label(label="Prediction Result"),
77
+ gr.Textbox(label="Confidence Score"),
78
+ gr.Image(label="Prediction Visualization")
79
+ ],
80
+ title="🦴 Bone Fracture Detection System",
81
+ description="Upload an X-ray image to detect bone fractures using GLCM features and SVM classifier",
82
+ examples=[
83
+ [os.path.join("samples", "fractured_1.jpg")],
84
+ [os.path.join("samples", "normal_1.jpg")]
85
+ ],
86
+ flagging_mode="never"
87
+ )
88
+
89
+ if __name__ == "__main__":
90
+ iface.launch(server_name="0.0.0.0", server_port=7860)