# app/services/ml_model.py
# (Hugging Face upload by TheDeepDas, commit 8ff3e07 "Final")
# Inference-only ML model service
# Models are pre-trained and saved as .pkl files
import numpy as np
import re
from pathlib import Path
import joblib
import logging
logger = logging.getLogger(__name__)

# Optional scikit-learn dependency: when it is missing the service
# degrades gracefully to rule-based classification instead of crashing.
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.ensemble import RandomForestClassifier
except ImportError as e:
    SKLEARN_AVAILABLE = False
    logger.warning(f"sklearn not available: {e}. Using rule-based classification.")
else:
    SKLEARN_AVAILABLE = True
    logger.info("sklearn imported successfully")

# Resolve <repo-root>/models relative to this file and make sure it exists.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
MODEL_DIR = BASE_DIR / "models"
MODEL_DIR.mkdir(exist_ok=True)
class IncidentClassifier:
    """Classifies an incident report into a threat type and a severity level.

    Prefers pre-trained scikit-learn pipelines loaded from ``MODEL_DIR``;
    when the models are unavailable (or prediction fails) it falls back to
    keyword rules so callers always get an answer.
    """

    # Keyword tables for the rule-based fallback (substring matches on the
    # lowercased "name description" text).
    _OIL_KEYWORDS = ('oil', 'petroleum', 'crude', 'spill', 'tanker')
    _CHEMICAL_KEYWORDS = ('chemical', 'toxic', 'hazardous', 'acid', 'industrial')
    _HIGH_SEVERITY = ('major', 'massive', 'large', 'explosion', 'fire',
                      'emergency', 'critical', 'severe')
    _MEDIUM_SEVERITY = ('moderate', 'contained', 'limited', 'minor')

    def __init__(self):
        self.threat_model = None
        self.severity_model = None
        self.is_trained = False
        # Best-effort load at construction time; failure is not fatal because
        # predict() degrades to the rule-based path.
        try:
            if self.load_models():
                logger.info("Pre-trained models loaded successfully")
            else:
                logger.warning("No pre-trained models found. Classification will use fallback rules.")
        except Exception as e:
            logger.warning("Failed to load models on initialization: %s", e)

    def preprocess_text(self, text):
        """Normalize free text: lowercase, strip punctuation, collapse whitespace.

        Returns "" for None/empty input so downstream vectorizers never see None.
        """
        if text is None or text == "":
            return ""
        text = str(text).lower()
        # Replace anything that is not alphanumeric/whitespace with a space
        # (input is already lowercased), then collapse whitespace runs.
        text = re.sub(r'[^a-z0-9\s]', ' ', text)
        return re.sub(r'\s+', ' ', text).strip()

    def load_models(self):
        """Load the threat and severity pipelines from ``MODEL_DIR``.

        Returns:
            True when both .pkl files exist and load cleanly, else False.
        """
        if not SKLEARN_AVAILABLE:
            logger.warning("sklearn not available, cannot load models")
            return False
        try:
            threat_model_path = MODEL_DIR / "threat_model.pkl"
            severity_model_path = MODEL_DIR / "severity_model.pkl"
            if threat_model_path.exists() and severity_model_path.exists():
                # NOTE(review): joblib.load deserializes pickles — only safe
                # because the .pkl files ship with the application.
                self.threat_model = joblib.load(threat_model_path)
                self.severity_model = joblib.load(severity_model_path)
                self.is_trained = True
                logger.info("Models loaded successfully")
                return True
            else:
                logger.warning("Model files not found")
                return False
        except Exception as e:
            logger.error("Error loading models: %s", e)
            return False

    def predict(self, description, name=""):
        """Predict threat type and severity for an incident.

        Args:
            description: Free-text incident description.
            name: Optional incident title, prepended to the description.

        Returns:
            dict with 'threat' and 'severity' (plain ``str``) plus
            'threat_confidence' / 'severity_confidence' floats.  Never
            raises: any ML failure falls back to rule-based classification.
        """
        combined_text = f"{name} {description}".lower()
        # Hard-coded keyword override: plastic reports map to the Chemical
        # class (the label set used during model training).
        if 'plastic' in combined_text:
            logger.info("Plastic keyword detected - using basic classification")
            return {
                'threat': 'Chemical',
                'severity': 'medium',
                'threat_confidence': 0.95,  # high confidence for keyword match
                'severity_confidence': 0.92,
            }
        if not self.is_trained:
            return self._rule_based_classification(description, name)
        try:
            preprocessed_text = self.preprocess_text(combined_text)
            if not preprocessed_text:
                return self._rule_based_classification(description, name)
            threat_pred = self.threat_model.predict([preprocessed_text])[0]
            severity_pred = self.severity_model.predict([preprocessed_text])[0]
            threat_proba = self.threat_model.predict_proba([preprocessed_text])[0]
            severity_proba = self.severity_model.predict_proba([preprocessed_text])[0]
            return {
                # str() so the ML path matches the rule-based path: sklearn
                # may return numpy.str_, which breaks JSON serialization.
                'threat': str(threat_pred),
                'severity': str(severity_pred),
                # Confidence = max class probability.
                'threat_confidence': float(np.max(threat_proba)),
                'severity_confidence': float(np.max(severity_proba)),
            }
        except Exception as e:
            logger.error("Prediction error: %s", e)
            return self._rule_based_classification(description, name)

    def _rule_based_classification(self, description, name=""):
        """Keyword fallback used when ML models are unavailable or fail."""
        combined_text = f"{name} {description}".lower()
        # Threat: first matching keyword family wins, Oil before Chemical.
        if any(k in combined_text for k in self._OIL_KEYWORDS):
            threat = 'Oil'
        elif any(k in combined_text for k in self._CHEMICAL_KEYWORDS):
            threat = 'Chemical'
        else:
            threat = 'Other'
        # Severity: high indicators take precedence over medium; default low.
        if any(k in combined_text for k in self._HIGH_SEVERITY):
            severity = 'high'
        elif any(k in combined_text for k in self._MEDIUM_SEVERITY):
            severity = 'medium'
        else:
            severity = 'low'
        return {
            'threat': threat,
            'severity': severity,
            'threat_confidence': 0.8,  # fixed values: rules yield no probabilities
            'severity_confidence': 0.7,
        }
# Module-level singleton so every caller shares one loaded model set.
incident_classifier = IncidentClassifier()


def get_classifier():
    """Return the shared IncidentClassifier singleton."""
    return incident_classifier


def predict_incident(description, name=""):
    """Convenience wrapper: classify an incident via the shared classifier."""
    return get_classifier().predict(description, name)