File size: 6,577 Bytes
6c9c901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ff3e07
 
 
 
 
 
 
 
 
 
 
 
 
6c9c901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Inference-only ML model service
# Models are pre-trained and saved as .pkl files

import numpy as np
import re
from pathlib import Path
import joblib
import logging

logger = logging.getLogger(__name__)

# Minimal sklearn imports for model loading
# NOTE(review): TfidfVectorizer / RandomForestClassifier are not referenced
# directly below — presumably imported so the pickled pipelines can be
# unpickled by joblib and to detect sklearn availability. TODO confirm.
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.ensemble import RandomForestClassifier
    SKLEARN_AVAILABLE = True
    logger.info("sklearn imported successfully")
except ImportError as e:
    # Without sklearn the classifier silently degrades to keyword rules.
    SKLEARN_AVAILABLE = False
    logger.warning(f"sklearn not available: {e}. Using rule-based classification.")

# Get the directory where this file is located
# Three levels up from this file — presumably the project root; verify layout.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
MODEL_DIR = BASE_DIR / "models"
# Import-time side effect: creates the models directory if it is missing.
MODEL_DIR.mkdir(exist_ok=True)

class IncidentClassifier:
    """Classify incident reports into a threat type and a severity level.

    Uses pre-trained sklearn models unpickled from ``MODEL_DIR`` when they
    are available; otherwise falls back to keyword-based rules so callers
    always receive a result dict of the same shape.
    """

    def __init__(self):
        # Loaded sklearn models (remain None until load_models() succeeds).
        self.threat_model = None
        self.severity_model = None
        self.is_trained = False

        # Try to load pre-trained models automatically. Failure is non-fatal:
        # predict() degrades to rule-based classification.
        try:
            if self.load_models():
                logger.info("Pre-trained models loaded successfully")
            else:
                logger.warning("No pre-trained models found. Classification will use fallback rules.")
        except Exception as e:
            # Lazy %-formatting so the message is only built when emitted.
            logger.warning("Failed to load models on initialization: %s", e)

    def preprocess_text(self, text) -> str:
        """Normalize raw text for the vectorizer.

        Lowercases, replaces non-alphanumeric characters with spaces, and
        collapses runs of whitespace. ``None`` or ``""`` yields ``""``.
        """
        if text is None or text == "":
            return ""

        # Lowercase first so all downstream matching is case-insensitive.
        text = str(text).lower()

        # Replace anything that is not alphanumeric/whitespace with a space.
        text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)

        # Collapse internal whitespace and trim the ends.
        return re.sub(r'\s+', ' ', text).strip()

    def load_models(self) -> bool:
        """Load the threat and severity models from ``MODEL_DIR``.

        Returns True only when BOTH pickle files exist and unpickle cleanly;
        returns False when sklearn is unavailable, a file is missing, or
        loading raises (the error is logged, never propagated).
        """
        if not SKLEARN_AVAILABLE:
            logger.warning("sklearn not available, cannot load models")
            return False

        try:
            threat_model_path = MODEL_DIR / "threat_model.pkl"
            severity_model_path = MODEL_DIR / "severity_model.pkl"

            # Both models are required; a single missing file means fallback.
            if threat_model_path.exists() and severity_model_path.exists():
                self.threat_model = joblib.load(threat_model_path)
                self.severity_model = joblib.load(severity_model_path)
                self.is_trained = True
                logger.info("Models loaded successfully")
                return True

            logger.warning("Model files not found")
            return False
        except Exception as e:
            logger.error("Error loading models: %s", e)
            return False

    def predict(self, description, name="") -> dict:
        """Predict threat type and severity for an incident.

        Args:
            description: Free-text incident description.
            name: Optional incident name, prepended to the description.

        Returns:
            Dict with keys ``threat``, ``severity``, ``threat_confidence``,
            ``severity_confidence``. Falls back to rule-based classification
            when models are not loaded, input is empty, or prediction fails.
        """
        # Combine name and description for keyword checking.
        combined_text = f"{name} {description}".lower()

        # Hard-coded shortcut: 'plastic' is always a medium-severity Chemical
        # threat (Chemical is the class label used in model training).
        if 'plastic' in combined_text:
            logger.info("Plastic keyword detected - using basic classification")
            return {
                'threat': 'Chemical',
                'severity': 'medium',
                'threat_confidence': 0.95,  # High confidence for keyword match
                'severity_confidence': 0.92
            }

        if not self.is_trained:
            # Fallback to rule-based classification.
            return self._rule_based_classification(description, name)

        try:
            preprocessed_text = self.preprocess_text(combined_text)
            if not preprocessed_text:
                return self._rule_based_classification(description, name)

            # Cast labels to plain str: sklearn may return numpy scalar
            # strings, which would otherwise leak into (e.g.) JSON encoding.
            threat_pred = str(self.threat_model.predict([preprocessed_text])[0])
            severity_pred = str(self.severity_model.predict([preprocessed_text])[0])

            # Confidence = maximum class probability from each model.
            threat_proba = self.threat_model.predict_proba([preprocessed_text])[0]
            severity_proba = self.severity_model.predict_proba([preprocessed_text])[0]
            threat_confidence = float(np.max(threat_proba))
            severity_confidence = float(np.max(severity_proba))

            return {
                'threat': threat_pred,
                'severity': severity_pred,
                'threat_confidence': threat_confidence,
                'severity_confidence': severity_confidence
            }
        except Exception as e:
            logger.error("Prediction error: %s", e)
            return self._rule_based_classification(description, name)

    def _rule_based_classification(self, description, name="") -> dict:
        """Keyword-rule fallback used when ML models are not available.

        Returns the same dict shape as predict(), with fixed mock
        confidence scores (0.8 / 0.7).
        """
        combined_text = f"{name} {description}".lower()

        # Threat classification: first matching keyword family wins.
        if any(keyword in combined_text for keyword in ('oil', 'petroleum', 'crude', 'spill', 'tanker')):
            threat = 'Oil'
        elif any(keyword in combined_text for keyword in ('chemical', 'toxic', 'hazardous', 'acid', 'industrial')):
            threat = 'Chemical'
        else:
            threat = 'Other'

        # Severity classification: high indicators take precedence over medium;
        # anything unmatched defaults to low.
        high_indicators = ('major', 'massive', 'large', 'explosion', 'fire', 'emergency', 'critical', 'severe')
        medium_indicators = ('moderate', 'contained', 'limited', 'minor')

        if any(indicator in combined_text for indicator in high_indicators):
            severity = 'high'
        elif any(indicator in combined_text for indicator in medium_indicators):
            severity = 'medium'
        else:
            severity = 'low'

        # Return with confidence scores for consistency with predict().
        return {
            'threat': threat,
            'severity': severity,
            'threat_confidence': 0.8,  # Mock confidence for rule-based
            'severity_confidence': 0.7
        }

# Global classifier instance.
# NOTE(review): constructed at import time — __init__ attempts to load model
# files from disk, so importing this module performs filesystem I/O.
incident_classifier = IncidentClassifier()

def get_classifier():
    """Return the module-level IncidentClassifier singleton."""
    return incident_classifier

def predict_incident(description, name=""):
    """Classify an incident, delegating to the shared classifier instance."""
    return get_classifier().predict(description, name)