rjaditya commited on
Commit
855ba03
·
verified ·
1 Parent(s): 553ee32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -1,23 +1,31 @@
1
-
2
  import gradio as gr
3
  from ultralytics import YOLO
4
  from PIL import Image
5
- import numpy as np
6
 
7
- # Define the path to the trained YOLO model weights
8
- # This path should be accessible in the environment where the app is run
9
- model_path = "yolo_training/sickle_cls_model/weights/best.pt" # Update this path if needed
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Load the YOLO model
12
  model = YOLO(model_path)
13
 
14
- # Define the list of class names used during model training
15
  classes = ['sickle', 'non_sickle', 'AIN']
16
 
17
- # Create a function that takes an image and returns the prediction
18
  def predict_image(img):
19
- # Make prediction
20
- # YOLO expects a list of images for prediction
21
  results = model([img])
22
 
23
  # Extract top prediction from the first image result
@@ -34,8 +42,13 @@ interface = gr.Interface(
34
  inputs=gr.Image(type="pil"),
35
  outputs=[gr.Label(), gr.Number()],
36
  title="Sickle Cell Classification",
37
- description="Upload a blood smear image to classify it as Sickle Cell (sickle), Non-Sickle (non_sickle), or Artifact/Impurities/Noise (AIN)."
 
 
 
 
38
  )
39
 
40
- # Launch the interface (for local testing)
41
- # interface.launch()
 
 
1
+ import os
2
  import gradio as gr
3
  from ultralytics import YOLO
4
  from PIL import Image
 
5
 
6
+ # Suppress Ultralytics write warning by setting config dir
7
+ os.environ["YOLO_CONFIG_DIR"] = "UltralyticsConfig"
8
+ os.makedirs("UltralyticsConfig", exist_ok=True)
9
+
10
+ # Path to YOLO model weights
11
+ model_path = "yolo_training/sickle_cls_model/weights/best.pt"
12
+
13
+ # Check if model exists
14
+ if not os.path.exists(model_path):
15
+ raise FileNotFoundError(
16
+ f"Model file not found at {model_path}. "
17
+ "Make sure best.pt is uploaded to this path in your repo."
18
+ )
19
 
20
+ # Load YOLO model
21
  model = YOLO(model_path)
22
 
23
+ # Class names
24
  classes = ['sickle', 'non_sickle', 'AIN']
25
 
26
+ # Prediction function
27
  def predict_image(img):
28
+ # YOLO expects a list of images
 
29
  results = model([img])
30
 
31
  # Extract top prediction from the first image result
 
42
  inputs=gr.Image(type="pil"),
43
  outputs=[gr.Label(), gr.Number()],
44
  title="Sickle Cell Classification",
45
+ description=(
46
+ "Upload a blood smear image to classify it as "
47
+ "Sickle Cell (sickle), Non-Sickle (non_sickle), "
48
+ "or Artifact/Impurities/Noise (AIN)."
49
+ )
50
  )
51
 
52
+ # Launch interface
53
+ if __name__ == "__main__":
54
+ interface.launch()