SavlonBhai commited on
Commit
f996c5a
·
verified ·
1 Parent(s): 6614d3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -33
app.py CHANGED
@@ -3,42 +3,42 @@ from ultralytics import YOLO
3
  from PIL import Image
4
  import numpy as np
5
 
6
- # Load YOLO model (update path if needed)
7
- model = YOLO('best_animal_classifier.pt')
8
 
9
- class_names = ['butterflies', 'chickens', 'elephants', 'horses', 'spiders', 'squirrels']
 
10
 
11
- def predict_animal(image):
12
- if image is None:
13
- return {}
14
-
15
- # Convert numpy array input to PIL Image if needed
16
  if isinstance(image, np.ndarray):
17
  image = Image.fromarray(image)
18
-
19
- # Run prediction quietly
20
- results = model.predict(image, verbose=False)
21
-
22
- try:
23
- probs = results[0].probs.data.cpu().numpy()
24
- except AttributeError:
25
- # Fallback to uniform probabilities if probs unavailable
26
- probs = np.ones(len(class_names)) / len(class_names)
27
-
28
- return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
29
-
30
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
31
- gr.Markdown("# 🐾 Animal Type Classifier")
32
- gr.Markdown("Upload an image of an animal below and get predictions for butterflies, chickens, elephants, horses, spiders, or squirrels.")
33
-
34
- with gr.Row():
35
- img_input = gr.Image(type="pil", label="Upload Animal Image")
36
- label_output = gr.Label(num_top_classes=6, label="Prediction Scores")
37
-
38
- predict_button = gr.Button("Classify Animal")
39
- predict_button.click(fn=predict_animal, inputs=img_input, outputs=label_output)
40
-
41
- gr.Markdown("Developed with Ultralytics YOLO and Gradio framework.")
 
 
42
 
43
  if __name__ == "__main__":
44
- demo.launch()
 
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ # Load your pretrained YOLOv8 model (replace with your model path if customized)
7
+ model = YOLO("yolov8n.pt") # Using tiny YOLOv8 pretrained weights for example
8
 
9
+ # Define the class names according to your trained model classes
10
+ class_names = ["butterfly", "chicken", "elephant", "horse", "spider", "squirrel"]
11
 
12
+ def classify_animal(image):
13
+ # Convert numpy array to PIL Image
 
 
 
14
  if isinstance(image, np.ndarray):
15
  image = Image.fromarray(image)
16
+ # YOLO expects numpy array input, convert back to np array
17
+ img_np = np.array(image)
18
+
19
+ # Perform prediction
20
+ results = model(img_np)
21
+
22
+ # Extract top prediction
23
+ if results and len(results) > 0:
24
+ boxes = results[0].boxes
25
+ if boxes is not None and len(boxes) > 0:
26
+ # Get the highest confidence prediction
27
+ best_box = boxes[0]
28
+ conf = best_box.conf[0].item()
29
+ class_id = int(best_box.cls[0].item())
30
+ class_name = class_names[class_id] if class_id < len(class_names) else "Unknown"
31
+ return f"Predicted: {class_name} (Confidence: {conf:.2f})"
32
+ return "No animal detected."
33
+
34
+ # Build Gradio interface
35
+ iface = gr.Interface(
36
+ fn=classify_animal,
37
+ inputs=gr.Image(type="numpy", label="Upload Animal Image"),
38
+ outputs="text",
39
+ title="Real-Time Animal Type Classification",
40
+ description="Upload an image of an animal to classify it among butterfly, chicken, elephant, horse, spider, and squirrel."
41
+ )
42
 
43
  if __name__ == "__main__":
44
+ iface.launch()