SavlonBhai commited on
Commit
278cdaa
·
verified ·
1 Parent(s): 0e98f3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -107
app.py CHANGED
@@ -10,127 +10,71 @@ import numpy as np
10
  import logging
11
  import sys
12
  import os
13
- from typing import Optional, Tuple
14
 
15
- # ============================================================================
16
- # LOGGING CONFIGURATION
17
- # ============================================================================
18
  logging.basicConfig(
19
  level=logging.INFO,
20
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
21
- handlers=[
22
- logging.FileHandler('animal_classifier.log'),
23
- logging.StreamHandler(sys.stdout)
24
- ]
25
  )
26
  logger = logging.getLogger(__name__)
27
 
28
- # ============================================================================
29
- # CONFIGURATION
30
- # ============================================================================
31
  MODEL_PATH = "best_animal_classifier.pt"
32
  CLASS_NAMES = ["butterfly", "chicken", "elephant", "horse", "spider", "squirrel"]
33
- CONFIDENCE_THRESHOLD = 0.5
34
- MIN_DETECTIONS = 1
35
 
36
- # ============================================================================
37
- # GLOBAL MODEL VARIABLE
38
- # ============================================================================
39
- model = None
40
-
41
- # ============================================================================
42
- # MODEL INITIALIZATION WITH EXCEPTION HANDLING
43
- # ============================================================================
44
- def load_model() -> Optional[YOLO]:
45
- """
46
- Load the YOLO model with comprehensive error handling.
47
-
48
- Returns:
49
- YOLO: Loaded model object or None if loading fails
50
- """
51
- global model
52
-
53
- try:
54
- logger.info(f"Attempting to load model from: {MODEL_PATH}")
55
-
56
- # Check if model file exists
57
- if not os.path.exists(MODEL_PATH):
58
- logger.error(f"Model file not found at: {MODEL_PATH}")
59
- raise FileNotFoundError(
60
- f"Model file '{MODEL_PATH}' does not exist. "
61
- f"Please ensure the file is in the correct location."
62
- )
63
-
64
- # Check if file has read permissions
65
- if not os.access(MODEL_PATH, os.R_OK):
66
- logger.error(f"No read permission for model file: {MODEL_PATH}")
67
- raise PermissionError(
68
- f"No read permission for model file: {MODEL_PATH}"
69
- )
70
-
71
- # Load the model
72
  model = YOLO(MODEL_PATH)
73
  logger.info("✅ Model loaded successfully!")
74
-
75
- return model
76
-
77
- except FileNotFoundError as e:
78
- logger.error(f"FileNotFoundError: {e}")
79
- return None
80
-
81
- except PermissionError as e:
82
- logger.error(f"PermissionError: {e}")
83
- return None
84
-
85
- except Exception as e:
86
- logger.error(f"Unexpected error loading model: {type(e).__name__}: {e}")
87
- return None
88
-
89
 
90
- # ============================================================================
91
- # CLASSIFICATION FUNCTION WITH ROBUST ERROR HANDLING
92
- # ============================================================================
93
- def classify_animal(image: Optional[np.ndarray]) -> str:
94
- """
95
- Classify an animal in the provided image using YOLOv8.
96
-
97
- Args:
98
- image (Optional[np.ndarray]): Input image as numpy array or PIL Image
99
-
100
- Returns:
101
- str: Classification result with confidence score or error message
102
- """
103
 
 
 
 
104
  try:
105
- # ========== INPUT VALIDATION ==========
106
- if image is None:
107
- logger.warning("No image provided")
108
- return "❌ Error: No image provided. Please upload an image."
109
-
110
- logger.info("Image received for classification")
111
 
112
- # ========== MODEL AVAILABILITY CHECK ==========
113
- if model is None:
114
- logger.error("Model is not loaded")
115
- return "❌ Critical Error: Model not loaded. Please restart the application."
116
 
117
- # ========== IMAGE TYPE CONVERSION ==========
118
- try:
119
- if isinstance(image, np.ndarray):
120
- # Validate numpy array dimensions
121
- if image.ndim not in [2, 3, 4]:
122
- logger.error(f"Invalid image dimensions: {image.ndim}")
123
- return "❌ Error: Invalid image dimensions. Expected 2D, 3D, or 4D array."
124
-
125
- # Validate data type
126
- if not np.issubdtype(image.dtype, np.integer):
127
- logger.warning(f"Unexpected image dtype: {image.dtype}, attempting conversion")
128
- image = image.astype('uint8')
129
-
130
- # Convert to PIL Image
131
- image_pil = Image.fromarray(image.astype('uint8'))
132
- logger.debug("Converted numpy array to PIL Image")
133
 
134
- elif isinstance(image, Image.Image):
135
- image_pil = image
136
- lo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import logging
11
  import sys
12
  import os
13
+ from typing import Optional
14
 
15
+ # Logging Configuration
 
 
16
  logging.basicConfig(
17
  level=logging.INFO,
18
+ format='%(asctime)s - %(levelname)s - %(message)s',
19
+ handlers=[logging.StreamHandler(sys.stdout)]
 
 
 
20
  )
21
  logger = logging.getLogger(__name__)
22
 
23
+ # Configuration
 
 
24
  MODEL_PATH = "best_animal_classifier.pt"
25
  CLASS_NAMES = ["butterfly", "chicken", "elephant", "horse", "spider", "squirrel"]
 
 
26
 
27
+ # Load the model
28
+ try:
29
+ if os.path.exists(MODEL_PATH):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  model = YOLO(MODEL_PATH)
31
  logger.info("✅ Model loaded successfully!")
32
+ else:
33
+ logger.error(f"❌ Model file not found at {MODEL_PATH}")
34
+ model = None
35
+ except Exception as e:
36
+ logger.error(f"❌ Error loading model: {e}")
37
+ model = None
 
 
 
 
 
 
 
 
 
38
 
39
+ def classify_animal(image):
40
+ if image is None:
41
+ return "Please upload an image."
 
 
 
 
 
 
 
 
 
 
42
 
43
+ if model is None:
44
+ return "Model not loaded. Check server logs."
45
+
46
  try:
47
+ # Run inference
48
+ results = model(image)
 
 
 
 
49
 
50
+ # YOLOv8 classification returns a list of results
51
+ # We take the top prediction from the first result
52
+ result = results[0]
 
53
 
54
+ if result.probs is not None:
55
+ # Get index of the highest probability
56
+ top1_idx = result.probs.top1
57
+ conf = result.probs.top1conf.item()
58
+ label = result.names[top1_idx]
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ return f"Prediction: {label.upper()} ({conf:.2%})"
61
+ else:
62
+ return "No animals detected or classification failed."
63
+
64
+ except Exception as e:
65
+ logger.error(f"Inference error: {e}")
66
+ return f"Error during classification: {str(e)}"
67
+
68
+ # Gradio Interface
69
+ demo = gr.Interface(
70
+ fn=classify_animal,
71
+ inputs=gr.Image(type="pil", label="Upload Animal Image"),
72
+ outputs=gr.Textbox(label="Result"),
73
+ title="🐾 Animal Type Classifier",
74
+ description="Upload a photo of a butterfly, chicken, elephant, horse, spider, or squirrel to identify it.",
75
+ examples=[["example_elephant.jpg"]] if os.path.exists("example_elephant.jpg") else None,
76
+ cache_examples=False
77
+ )
78
+
79
+ if __name__ == "__main__":
80
+ demo.launch(server_name="0.0.0.0", server_port=7860)