TheDeepDas committed on
Commit
8ff3e07
·
1 Parent(s): 8e1f75a
Files changed (3) hide show
  1. app/services/ml_model.py +13 -2
  2. start-hf.sh +0 -9
  3. test_yolo_model.py +0 -158
app/services/ml_model.py CHANGED
@@ -80,13 +80,24 @@ class IncidentClassifier:
80
 
81
  def predict(self, description, name=""):
82
  """Predict threat type and severity for an incident"""
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if not self.is_trained:
84
  # Fallback to rule-based classification
85
  return self._rule_based_classification(description, name)
86
 
87
  try:
88
- # Combine name and description
89
- combined_text = f"{name} {description}".strip()
90
  preprocessed_text = self.preprocess_text(combined_text)
91
 
92
  if not preprocessed_text:
 
80
 
81
  def predict(self, description, name=""):
82
  """Predict threat type and severity for an incident"""
83
+ # Combine name and description for keyword checking
84
+ combined_text = f"{name} {description}".lower()
85
+
86
+ # Basic keyword check for plastic - classify as Chemical threat with medium severity
87
+ if 'plastic' in combined_text:
88
+ logger.info("Plastic keyword detected - using basic classification")
89
+ return {
90
+ 'threat': 'Chemical', # Use Chemical as the threat class (as defined in model training)
91
+ 'severity': 'medium',
92
+ 'threat_confidence': 0.95, # High confidence for keyword match
93
+ 'severity_confidence': 0.92
94
+ }
95
+
96
  if not self.is_trained:
97
  # Fallback to rule-based classification
98
  return self._rule_based_classification(description, name)
99
 
100
  try:
 
 
101
  preprocessed_text = self.preprocess_text(combined_text)
102
 
103
  if not preprocessed_text:
start-hf.sh CHANGED
@@ -12,14 +12,5 @@ export ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-"*"}
12
  echo "📡 Port: ${PORT:-7860}"
13
  echo "🔗 Allowed Origins: $ALLOWED_ORIGINS"
14
 
15
- # Test YOLO model availability (quick test)
16
- echo "🤖 Testing YOLO model..."
17
- python test_yolo_model.py
18
- if [ $? -eq 0 ]; then
19
- echo "✓ YOLO model ready"
20
- else
21
- echo "⚠️ YOLO model test failed, object detection may not work"
22
- fi
23
-
24
  # Start the FastAPI application
25
  exec uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1
 
12
  echo "📡 Port: ${PORT:-7860}"
13
  echo "🔗 Allowed Origins: $ALLOWED_ORIGINS"
14
 
 
 
 
 
 
 
 
 
 
15
  # Start the FastAPI application
16
  exec uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1
test_yolo_model.py DELETED
@@ -1,158 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test script to download and load YOLO model for marine pollution detection.
4
- This ensures the model is available before the main application starts.
5
- """
6
-
7
- import os
8
- import sys
9
- import logging
10
- from pathlib import Path
11
-
12
- # Configure logging
13
- logging.basicConfig(
14
- level=logging.INFO,
15
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
- )
17
- logger = logging.getLogger(__name__)
18
-
19
- def test_opencv():
20
- """Test if OpenCV is available"""
21
- try:
22
- import cv2
23
- logger.info(f"✓ OpenCV loaded successfully: {cv2.__version__}")
24
- return True
25
- except ImportError as e:
26
- logger.error(f"✗ OpenCV not available: {e}")
27
- return False
28
-
29
- def test_torch():
30
- """Test if PyTorch is available"""
31
- try:
32
- import torch
33
- logger.info(f"✓ PyTorch loaded successfully: {torch.__version__}")
34
- logger.info(f" CUDA available: {torch.cuda.is_available()}")
35
- if torch.cuda.is_available():
36
- logger.info(f" CUDA device: {torch.cuda.get_device_name(0)}")
37
- else:
38
- logger.info(" Using CPU for inference")
39
- return True
40
- except ImportError as e:
41
- logger.error(f"✗ PyTorch not available: {e}")
42
- return False
43
-
44
- def test_ultralytics():
45
- """Test if Ultralytics YOLO is available"""
46
- try:
47
- from ultralytics import YOLO
48
- logger.info("✓ Ultralytics YOLO loaded successfully")
49
- return True
50
- except ImportError as e:
51
- logger.error(f"✗ Ultralytics not available: {e}")
52
- return False
53
-
54
- def download_and_test_model():
55
- """Download and test the YOLO model"""
56
- try:
57
- from ultralytics import YOLO
58
- import torch
59
-
60
- # Set to CPU mode if no CUDA
61
- if not torch.cuda.is_available():
62
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
63
- logger.info("Forcing CPU mode (no CUDA available)")
64
-
65
- # Model to download (YOLOv8x - largest/most accurate)
66
- model_name = "yolov8x.pt"
67
- logger.info(f"Attempting to load/download {model_name}...")
68
-
69
- # Check if model already exists
70
- model_path = Path.home() / ".cache" / "ultralytics" / model_name
71
- if model_path.exists():
72
- logger.info(f"✓ Model already exists at: {model_path}")
73
- else:
74
- logger.info(f"Model not found, will download to: {model_path}")
75
-
76
- # Load the model (will auto-download if not present)
77
- logger.info("Loading model...")
78
- model = YOLO(model_name)
79
-
80
- # Verify model loaded
81
- if hasattr(model, 'model') and model.model is not None:
82
- logger.info(f"✓ Model loaded successfully!")
83
-
84
- # Get model info
85
- try:
86
- model_info = model.info()
87
- if isinstance(model_info, dict):
88
- logger.info(f" Model type: {model_info.get('model_type', 'unknown')}")
89
- elif hasattr(model_info, 'model_type'):
90
- logger.info(f" Model type: {model_info.model_type}")
91
- except Exception as e:
92
- logger.warning(f"Could not get detailed model info: {e}")
93
-
94
- # Test inference with a dummy image
95
- logger.info("Testing inference with dummy image...")
96
- import numpy as np
97
- import tempfile
98
- import cv2
99
-
100
- # Create a small test image
101
- test_img = np.zeros((100, 100, 3), dtype=np.uint8)
102
- temp_path = tempfile.mktemp(suffix='.jpg')
103
- cv2.imwrite(temp_path, test_img)
104
-
105
- # Run inference
106
- results = model(temp_path, verbose=False)
107
- os.unlink(temp_path)
108
-
109
- logger.info("✓ Model inference test successful!")
110
- return True
111
- else:
112
- logger.error("✗ Model loaded but verification failed")
113
- return False
114
-
115
- except Exception as e:
116
- logger.error(f"✗ Failed to download/test model: {e}", exc_info=True)
117
- return False
118
-
119
- def main():
120
- """Run all tests"""
121
- logger.info("=" * 60)
122
- logger.info("YOLO Model Test Suite")
123
- logger.info("=" * 60)
124
-
125
- results = {
126
- "OpenCV": test_opencv(),
127
- "PyTorch": test_torch(),
128
- "Ultralytics": test_ultralytics(),
129
- }
130
-
131
- # Only test model if all dependencies are available
132
- if all(results.values()):
133
- logger.info("\nAll dependencies available. Testing model...")
134
- results["YOLO Model"] = download_and_test_model()
135
- else:
136
- logger.error("\nMissing dependencies. Skipping model test.")
137
- results["YOLO Model"] = False
138
-
139
- # Print summary
140
- logger.info("\n" + "=" * 60)
141
- logger.info("Test Summary:")
142
- logger.info("=" * 60)
143
- for test, passed in results.items():
144
- status = "✓ PASS" if passed else "✗ FAIL"
145
- logger.info(f" {test:20s}: {status}")
146
-
147
- all_passed = all(results.values())
148
- logger.info("=" * 60)
149
-
150
- if all_passed:
151
- logger.info("✓ All tests passed! System ready for object detection.")
152
- return 0
153
- else:
154
- logger.error("✗ Some tests failed. Check logs above.")
155
- return 1
156
-
157
- if __name__ == "__main__":
158
- sys.exit(main())