Spaces:
Sleeping
Sleeping
| # Inference-only ML model service | |
| # Models are pre-trained and saved as .pkl files | |
| import numpy as np | |
| import re | |
| from pathlib import Path | |
| import joblib | |
| import logging | |
logger = logging.getLogger(__name__)

# sklearn is optional: when the import fails, IncidentClassifier skips model
# loading and predict() falls back to rule-based classification.
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.ensemble import RandomForestClassifier

    SKLEARN_AVAILABLE = True
    logger.info("sklearn imported successfully")
except ImportError as e:
    SKLEARN_AVAILABLE = False
    logger.warning("sklearn not available: %s. Using rule-based classification.", e)

# Models directory lives three levels above this file (repo root / "models").
BASE_DIR = Path(__file__).resolve().parent.parent.parent
MODEL_DIR = BASE_DIR / "models"
# parents=True so a missing intermediate directory does not raise
# FileNotFoundError on first run (exist_ok alone only tolerates the leaf dir).
MODEL_DIR.mkdir(parents=True, exist_ok=True)
class IncidentClassifier:
    """Classify an incident's threat type and severity from free text.

    Prediction strategy, in order:
      1. hard-coded 'plastic' keyword shortcut (returns Chemical / medium),
      2. pre-trained sklearn models loaded from ``MODEL_DIR`` (if present),
      3. keyword-based fallback rules.

    All prediction methods return a dict with keys ``threat``, ``severity``,
    ``threat_confidence`` and ``severity_confidence``.
    """

    # Compiled once at class creation; preprocess_text runs per prediction,
    # so recompiling these patterns on every call is wasted work.
    _NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9\s]')
    _MULTI_WS_RE = re.compile(r'\s+')

    # Keyword lists for the rule-based fallback (substring matches).
    _OIL_KEYWORDS = ('oil', 'petroleum', 'crude', 'spill', 'tanker')
    _CHEMICAL_KEYWORDS = ('chemical', 'toxic', 'hazardous', 'acid', 'industrial')
    _HIGH_INDICATORS = ('major', 'massive', 'large', 'explosion', 'fire',
                        'emergency', 'critical', 'severe')
    _MEDIUM_INDICATORS = ('moderate', 'contained', 'limited', 'minor')

    def __init__(self):
        self.threat_model = None    # loaded sklearn model for threat type
        self.severity_model = None  # loaded sklearn model for severity
        self.is_trained = False
        # Loading is best-effort: any failure leaves is_trained False and
        # predict() transparently uses the rule-based fallback.
        try:
            if self.load_models():
                logger.info("Pre-trained models loaded successfully")
            else:
                logger.warning("No pre-trained models found. Classification will use fallback rules.")
        except Exception as e:
            logger.warning("Failed to load models on initialization: %s", e)

    def preprocess_text(self, text):
        """Normalize text: lowercase, strip punctuation, collapse whitespace.

        Returns "" for None or empty input; any other value is coerced
        with str() first.
        """
        if text is None or text == "":
            return ""
        text = str(text).lower()
        # Replace anything that is not alphanumeric/whitespace with a space,
        # then collapse whitespace runs into single spaces.
        text = self._NON_ALNUM_RE.sub(' ', text)
        return self._MULTI_WS_RE.sub(' ', text).strip()

    def load_models(self):
        """Load the threat and severity models from MODEL_DIR.

        Returns True and sets is_trained when both .pkl files load;
        returns False (never raises) otherwise.
        """
        if not SKLEARN_AVAILABLE:
            logger.warning("sklearn not available, cannot load models")
            return False
        try:
            threat_model_path = MODEL_DIR / "threat_model.pkl"
            severity_model_path = MODEL_DIR / "severity_model.pkl"
            # Both models must exist; a partial set would give inconsistent
            # predictions, so we only flip is_trained when both load.
            if threat_model_path.exists() and severity_model_path.exists():
                self.threat_model = joblib.load(threat_model_path)
                self.severity_model = joblib.load(severity_model_path)
                self.is_trained = True
                logger.info("Models loaded successfully")
                return True
            logger.warning("Model files not found")
            return False
        except Exception as e:
            logger.error("Error loading models: %s", e)
            return False

    def predict(self, description, name=""):
        """Predict threat type and severity for an incident.

        Falls back to rule-based classification when models are not loaded,
        when the combined text preprocesses to empty, or when model
        inference raises.
        """
        # Name and description are classified together.
        combined_text = f"{name} {description}".lower()
        # Hard-coded shortcut: 'plastic' maps to the Chemical class used in
        # model training, with fixed high confidences.
        if 'plastic' in combined_text:
            logger.info("Plastic keyword detected - using basic classification")
            return {
                'threat': 'Chemical',
                'severity': 'medium',
                'threat_confidence': 0.95,
                'severity_confidence': 0.92
            }
        if not self.is_trained:
            return self._rule_based_classification(description, name)
        try:
            preprocessed_text = self.preprocess_text(combined_text)
            if not preprocessed_text:
                return self._rule_based_classification(description, name)
            threat_pred = self.threat_model.predict([preprocessed_text])[0]
            severity_pred = self.severity_model.predict([preprocessed_text])[0]
            # Confidence = highest class probability reported by each model.
            threat_proba = self.threat_model.predict_proba([preprocessed_text])[0]
            severity_proba = self.severity_model.predict_proba([preprocessed_text])[0]
            return {
                'threat': threat_pred,
                'severity': severity_pred,
                'threat_confidence': float(np.max(threat_proba)),
                'severity_confidence': float(np.max(severity_proba))
            }
        except Exception as e:
            logger.error("Prediction error: %s", e)
            return self._rule_based_classification(description, name)

    def _rule_based_classification(self, description, name=""):
        """Keyword-based fallback used when ML models are unavailable.

        Confidence scores are fixed mock values so the return shape matches
        model-backed predictions.
        """
        combined_text = f"{name} {description}".lower()
        # Threat: first matching keyword family wins; Oil checked before
        # Chemical, anything else is 'Other'.
        if any(keyword in combined_text for keyword in self._OIL_KEYWORDS):
            threat = 'Oil'
        elif any(keyword in combined_text for keyword in self._CHEMICAL_KEYWORDS):
            threat = 'Chemical'
        else:
            threat = 'Other'
        # Severity: high indicators take precedence over medium; default low.
        if any(indicator in combined_text for indicator in self._HIGH_INDICATORS):
            severity = 'high'
        elif any(indicator in combined_text for indicator in self._MEDIUM_INDICATORS):
            severity = 'medium'
        else:
            severity = 'low'
        return {
            'threat': threat,
            'severity': severity,
            'threat_confidence': 0.8,
            'severity_confidence': 0.7
        }
# Module-level singleton: constructed once at import time so model loading
# from disk happens a single time and every caller shares the same instance.
incident_classifier = IncidentClassifier()
def get_classifier():
    """Return the shared module-level IncidentClassifier instance."""
    return incident_classifier
def predict_incident(description, name=""):
    """Classify one incident using the shared classifier.

    Thin convenience wrapper: delegates to IncidentClassifier.predict and
    returns its dict of threat/severity labels and confidence scores.
    """
    return get_classifier().predict(description, name)