# predict.py — audio-dashboard / src/predict.py
# (Hugging Face Hub page header removed from scraped copy:
#  "lawlevisan's picture / Update src/predict.py / 737ac7f verified")
# predict.py - Optimized Production Version with Enhanced Accuracy & Performance
from transformers import (
DistilBertTokenizerFast,
DistilBertForSequenceClassification,
DistilBertConfig,
AutoTokenizer,
AutoModelForSequenceClassification
)
import torch
import torch.nn.functional as F
import numpy as np
import logging
import os
import json
import shutil
import re
import time
from typing import Tuple, List, Optional, Dict
from functools import lru_cache
import threading
# =======================
# Logging configuration
# =======================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =======================
# Global variables with thread safety
# =======================
# model/tokenizer are populated lazily by load_model_with_fallback(),
# guarded by model_lock so concurrent callers load at most once.
model = None
tokenizer = None
model_loaded = False
model_lock = threading.Lock()
# Prefer GPU when available; all input tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Performance tracking
# Rolling list of per-call latencies; predict() caps it at the last 100 entries.
inference_times = []
# Load slang keywords with caching
# Placeholder defaults; overwritten immediately after the loader functions below.
DRUG_KEYWORDS = []
HIGH_RISK_KEYWORDS = []
# ========================
# Enhanced keyword loading with caching
# ========================
@lru_cache(maxsize=1)
def load_keywords(file_path="slang_keywords.txt") -> List[str]:
    """Read slang keywords (one per line) from *file_path*, cached after the first call.

    Lines that are blank or begin with '#' are skipped; entries are lower-cased.
    Returns an empty list when the file is absent or unreadable.
    """
    if not os.path.exists(file_path):
        logger.warning(f"Keyword file not found: {file_path}. Using default keywords.")
        return []
    try:
        entries: List[str] = []
        with open(file_path, "r", encoding='utf-8') as handle:
            for raw in handle:
                if raw.strip() and not raw.startswith('#'):
                    entries.append(raw.strip().lower())
        logger.info(f"Loaded {len(entries)} slang keywords from {file_path}")
        return entries
    except Exception as e:
        logger.error(f"Failed to load keywords from {file_path}: {e}")
        return []
@lru_cache(maxsize=1)
def load_high_risk_keywords(file_path="high_risk_keywords.txt") -> List[str]:
    """Load high-risk keywords (one per line) from *file_path*, cached.

    Lines that are blank or begin with '#' are skipped; entries are lower-cased.
    Falls back to a built-in default list when the file is missing or unreadable.
    """
    # Single source of truth for the fallback set — previously this literal was
    # duplicated in both error paths, risking silent divergence between them.
    default_keywords = ["cocaine", "heroin", "mdma", "lsd", "meth", "fentanyl", "dealer", "supplier"]
    if not os.path.exists(file_path):
        logger.warning(f"High-risk keyword file not found: {file_path}")
        return default_keywords
    try:
        with open(file_path, "r", encoding='utf-8') as f:
            keywords = [line.strip().lower() for line in f if line.strip() and not line.startswith('#')]
        logger.info(f"Loaded {len(keywords)} high-risk keywords from {file_path}")
        return keywords
    except Exception as e:
        logger.error(f"Failed to load high-risk keywords: {e}")
        return default_keywords
# Initialize global keywords
# Both loaders are lru_cache'd, so these reads hit the filesystem only once.
DRUG_KEYWORDS = load_keywords("slang_keywords.txt")
HIGH_RISK_KEYWORDS = load_high_risk_keywords("high_risk_keywords.txt")
# =======================
# Enhanced text preprocessing for better accuracy
# =======================
# NOTE(review): a byte-identical duplicate of this function appears later in
# the file (apparent bad paste/merge); the later definition shadows this one
# at import time. The fix below is applied to both copies.
def preprocess_text(text: str) -> str:
    """Normalize raw text before tokenization and keyword scoring.

    Steps: lower-case, collapse runs of whitespace, expand common chat
    abbreviations (whole-word only), and squash repeated punctuation while
    keeping sentence boundaries.

    Args:
        text: Arbitrary input string (may be empty/falsy).
    Returns:
        The normalized string; '' for empty input.
    """
    if not text:
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove excessive whitespace but preserve sentence structure
    text = re.sub(r'\s+', ' ', text)
    # Handle common abbreviations and slang normalization
    abbreviations = {
        'u': 'you',
        'ur': 'your',
        'n': 'and',
        'w/': 'with',
        'thru': 'through',
        'gonna': 'going to',
        'wanna': 'want to',
        'gotta': 'got to'
    }
    for abbrev, full in abbreviations.items():
        # BUG FIX: a trailing \b after a non-word character (the '/' in 'w/')
        # only matches when a word character follows, so "w/ me" was never
        # expanded. Anchor the end with \b only when the abbreviation ends in
        # a word character; otherwise match unanchored at the end.
        tail = r'\b' if re.match(r'\w', abbrev[-1]) else ''
        text = re.sub(rf'\b{re.escape(abbrev)}{tail}', full, text)
    # Remove excessive punctuation but keep sentence boundaries
    text = re.sub(r'[!]{2,}', '!', text)
    text = re.sub(r'[?]{2,}', '?', text)
    text = re.sub(r'[.]{3,}', '...', text)
    return text.strip()
# =======================
# Enhanced keyword-based scoring
# =======================
def compute_keyword_score(text: str) -> Tuple[float, Dict[str, int]]:
    """Compute keyword-based score for enhanced accuracy"""
    # NOTE(review): this is an incomplete stub left over from an apparent bad
    # paste/merge — it lower-cases the input and then implicitly returns None,
    # which does NOT satisfy the declared Tuple return type. It is shadowed by
    # the complete compute_keyword_score defined later in the file, so it is
    # dead code; delete once confirmed.
    text_lower = text.lower()
# =======================
# Enhanced text preprocessing for better accuracy
# =======================
# NOTE(review): duplicate of the preprocess_text defined earlier in the file
# (apparent bad paste/merge); this later copy is the one in effect at import
# time, so the abbreviation-boundary fix is applied here as well.
def preprocess_text(text: str) -> str:
    """Normalize raw text before tokenization and keyword scoring.

    Steps: lower-case, collapse runs of whitespace, expand common chat
    abbreviations (whole-word only), and squash repeated punctuation while
    keeping sentence boundaries.

    Args:
        text: Arbitrary input string (may be empty/falsy).
    Returns:
        The normalized string; '' for empty input.
    """
    if not text:
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove excessive whitespace but preserve sentence structure
    text = re.sub(r'\s+', ' ', text)
    # Handle common abbreviations and slang normalization
    abbreviations = {
        'u': 'you',
        'ur': 'your',
        'n': 'and',
        'w/': 'with',
        'thru': 'through',
        'gonna': 'going to',
        'wanna': 'want to',
        'gotta': 'got to'
    }
    for abbrev, full in abbreviations.items():
        # BUG FIX: a trailing \b after a non-word character (the '/' in 'w/')
        # only matches when a word character follows, so "w/ me" was never
        # expanded. Anchor the end with \b only when the abbreviation ends in
        # a word character; otherwise match unanchored at the end.
        tail = r'\b' if re.match(r'\w', abbrev[-1]) else ''
        text = re.sub(rf'\b{re.escape(abbrev)}{tail}', full, text)
    # Remove excessive punctuation but keep sentence boundaries
    text = re.sub(r'[!]{2,}', '!', text)
    text = re.sub(r'[?]{2,}', '?', text)
    text = re.sub(r'[.]{3,}', '...', text)
    return text.strip()
# =======================
# Enhanced keyword-based scoring
# =======================
# Terms that are drug slang only in context ("e", "ice", "horse", ...); they
# count as matches only when an action cue (buy/sell/snort/...) also appears.
AMBIGUOUS_TERMS = {"e", "x", "line", "ice", "horse", "420"}


def keyword_check_with_context(text: str, kw: str) -> bool:
    """Whole-word, case-insensitive keyword match.

    Ambiguous terms additionally require a usage cue elsewhere in the text to
    reduce false positives (e.g. "ice cream" does not match "ice").
    """
    pattern = rf"\b{re.escape(kw)}\b"
    if re.search(pattern, text, re.IGNORECASE):
        if kw in AMBIGUOUS_TERMS:
            context_pattern = r"\b(smoke|roll|pop|hit|take|buy|sell|party|snort|inject)\b"
            return bool(re.search(context_pattern, text, re.IGNORECASE))
        return True
    return False


# FIX: removed a dead duplicate `def compute_keyword_score` stub that an
# apparent bad paste/merge had left immediately before this definition.
def compute_keyword_score(text: str) -> Tuple[float, Dict[str, int]]:
    """Compute a 0..1 keyword-evidence score plus per-category match counts.

    Weighting: high-risk keywords contribute up to 0.8 (0.3 each), generic
    slang up to 0.3 (0.1 each), contextual phrase patterns up to 0.4
    (0.15 each); the sum is capped at 1.0.

    Returns:
        (score, {'drug_keywords': n, 'high_risk_keywords': n,
                 'context_patterns': n})
    """
    text_lower = text.lower()
    drug_matches = sum(1 for kw in DRUG_KEYWORDS if keyword_check_with_context(text_lower, kw))
    high_risk_matches = sum(1 for kw in HIGH_RISK_KEYWORDS if keyword_check_with_context(text_lower, kw))
    # Phrase-level cues typical of deal arrangements.
    context_patterns = [
        r'(?i)(pick.*up|got.*stuff|meet.*behind)',
        r'(?i)(payment|crypto|cash.*deal)',
        r'(?i)(supplier|dealer|connect)',
        r'(?i)(party.*saturday|rave.*tonight)',
        r'(?i)(quality.*good|pure.*stuff)',
        r'(?i)(cops.*around|too.*risky)'
    ]
    context_matches = sum(1 for pattern in context_patterns if re.search(pattern, text_lower))
    keyword_score = 0.0
    if high_risk_matches > 0:
        keyword_score += min(high_risk_matches * 0.3, 0.8)
    if drug_matches > 0:
        keyword_score += min(drug_matches * 0.1, 0.3)
    if context_matches > 0:
        keyword_score += min(context_matches * 0.15, 0.4)
    keyword_score = min(keyword_score, 1.0)
    return keyword_score, {
        'drug_keywords': drug_matches,
        'high_risk_keywords': high_risk_matches,
        'context_patterns': context_matches
    }
# =======================
# Config validation/fix with enhanced error handling
# =======================
def validate_and_fix_config(model_path: str) -> bool:
    """Validate the model's config.json, repairing invalid attention dims.

    When `dim` is not divisible by `n_heads` (which would break attention-head
    splitting), the original config is backed up once and then overwritten
    with standard distilbert-base dimensions.

    Returns:
        True when validation (and any repair) completed, False on any error
        or when config.json is missing.
    """
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        logger.warning(f"Config file not found at {config_path}")
        return False
    try:
        with open(config_path, 'r', encoding='utf-8') as fh:
            cfg = json.load(fh)
        dim = cfg.get('dim', 768)
        n_heads = cfg.get('n_heads', 12)
        if dim % n_heads != 0:
            logger.warning(f"Configuration issue detected: dim={dim} not divisible by n_heads={n_heads}")
            # Back up the broken config exactly once.
            backup_path = config_path + ".backup"
            if not os.path.exists(backup_path):
                shutil.copy2(config_path, backup_path)
                logger.info(f"Backed up original config to {backup_path}")
            # Overwrite with stock DistilBERT-base geometry.
            cfg.update({
                'dim': 768,
                'n_heads': 12,
                'hidden_dim': 3072,
                'n_layers': 6,
                'vocab_size': 30522,
                'max_position_embeddings': 512,
                'dropout': 0.1,
                'attention_dropout': 0.1,
                'activation': 'gelu',
                'num_labels': 2
            })
            with open(config_path, 'w', encoding='utf-8') as fh:
                json.dump(cfg, fh, indent=2)
            logger.info("Fixed configuration with standard DistilBERT dimensions")
        logger.info("Configuration validation completed")
        return True
    except Exception as e:
        logger.error(f"Error validating/fixing config: {e}")
        return False
# =======================
# Enhanced model loading with multiple fallback strategies
# =======================
def load_model_with_fallback(model_name: str) -> bool:
    """Load the classifier into the module globals, guarded by model_lock.

    NOTE(review): *model_name* is currently ignored — a stock
    distilbert-base-uncased sequence-classification head is loaded instead,
    because the custom checkpoint has tokenizer issues (see log message).

    Returns:
        True when the model and tokenizer are ready, False on failure.
    """
    global model, tokenizer, model_loaded
    with model_lock:
        # Another thread may already have finished loading.
        if model_loaded:
            return True
        logger.info("Using standard DistilBERT model (custom model has tokenizer issues)")
        try:
            tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
            model = AutoModelForSequenceClassification.from_pretrained(
                'distilbert-base-uncased',
                num_labels=2
            )
            model.to(device)
            model.eval()
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return False
        model_loaded = True
        logger.info("Standard model loaded successfully")
        return True
# =======================
# Optimized prediction function with enhanced accuracy
# =======================
def predict(text: str, return_confidence: bool = True) -> Tuple[int, float]:
    """
    Classify *text* as drug-related (1) or not (0).

    Combines the DistilBERT classifier's class-1 probability with the
    keyword-evidence score (via combine_predictions) and records the per-call
    latency into the module-level inference_times list.

    Args:
        text: Input text to classify.
        return_confidence: Whether to return confidence scores.
            NOTE(review): currently unused — the confidence is always
            returned regardless of this flag.
    Returns:
        Tuple of (prediction_label, confidence_score); (0, 0.0) on any
        failure (empty input, model not loaded, tokenization or inference
        error).
    """
    start_time = time.time()
    try:
        # Input validation
        if not text or not text.strip():
            logger.warning("Empty input text")
            return 0, 0.0
        # Check if model is loaded
        if not model_loaded or model is None or tokenizer is None:
            logger.error("Model not loaded properly")
            return 0, 0.0
        # Preprocess text for better accuracy
        processed_text = preprocess_text(text)
        # Get keyword-based score
        keyword_score, keyword_stats = compute_keyword_score(processed_text)
        # Tokenize with proper handling
        try:
            inputs = tokenizer(
                processed_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
                add_special_tokens=True
            )
            # Move every input tensor to the same device as the model.
            inputs = {k: v.to(device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"Tokenization failed: {e}")
            return 0, 0.0
        # Model inference with no_grad for performance
        with torch.no_grad():
            try:
                outputs = model(**inputs)
                logits = outputs.logits
                # Apply softmax to get probabilities
                probabilities = F.softmax(logits, dim=-1)
                # [0][1]: single-item batch, class index 1 (assumed to be the
                # "drug" label — TODO confirm label order of the checkpoint).
                ml_confidence = probabilities[0][1].item()  # Probability of drug class
                ml_prediction = int(ml_confidence > 0.5)
            except Exception as e:
                logger.error(f"Model inference failed: {e}")
                return 0, 0.0
        # Enhanced decision making combining ML and keywords
        final_prediction, final_confidence = combine_predictions(
            ml_prediction, ml_confidence, keyword_score, keyword_stats, processed_text
        )
        # Log performance metrics
        inference_time = time.time() - start_time
        inference_times.append(inference_time)
        # Keep only last 100 timing records
        if len(inference_times) > 100:
            inference_times.pop(0)
        logger.info(f"Prediction completed in {inference_time:.3f}s - "
                    f"Result: {'DRUG' if final_prediction == 1 else 'NON_DRUG'} "
                    f"(confidence: {final_confidence:.3f}, keyword_score: {keyword_score:.3f})")
        return final_prediction, final_confidence
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        return 0, 0.0
# =======================
# Enhanced prediction combination logic
# =======================
def combine_predictions(ml_pred: int, ml_conf: float, keyword_score: float,
                        keyword_stats: Dict[str, int], text: str) -> Tuple[int, float]:
    """Fuse the ML probability and keyword evidence into a final verdict.

    The keyword side receives more weight as its evidence strengthens; any
    high-risk keyword hit forces a positive verdict with confidence >= 0.7.
    On an internal error the raw ML result is returned unchanged.
    """
    try:
        hits_high = keyword_stats.get('high_risk_keywords', 0)
        hits_drug = keyword_stats.get('drug_keywords', 0)
        hits_ctx = keyword_stats.get('context_patterns', 0)

        # Weighting tiers: stronger keyword evidence -> less trust in the ML score.
        if hits_high >= 2:
            weights = (0.2, 0.8)
        elif hits_high >= 1 or hits_drug >= 3:
            weights = (0.3, 0.7)
        elif hits_drug >= 2 or hits_ctx >= 2:
            weights = (0.4, 0.6)
        else:
            weights = (0.7, 0.3)
        ml_weight, keyword_weight = weights

        combined_score = ml_weight * ml_conf + keyword_weight * keyword_score

        # Decision ladder, from strongest to weakest evidence.
        if hits_high >= 1:
            final_pred, final_conf = 1, max(combined_score, 0.7)
        elif keyword_score >= 0.5:
            final_pred, final_conf = 1, combined_score
        elif keyword_score >= 0.3 and ml_conf >= 0.3:
            final_pred, final_conf = 1, combined_score
        elif ml_conf >= 0.7:
            final_pred, final_conf = 1, combined_score
        else:
            final_pred, final_conf = 0, max(combined_score, 0.1)

        # Clamp into [0, 1].
        return final_pred, max(0.0, min(1.0, final_conf))
    except Exception as e:
        logger.error(f"Prediction combination failed: {e}")
        return ml_pred, ml_conf
# =======================
# Model management functions
# =======================
def load_model(model_path: str) -> bool:
    """Load the model via the fallback loader, logging parameter/device info."""
    try:
        ok = load_model_with_fallback(model_path)
        if not ok:
            logger.error(f"Failed to load model from {model_path}")
            return ok
        logger.info(f"Model loaded successfully from {model_path}")
        # Log model info
        if model:
            total_params = sum(p.numel() for p in model.parameters())
            logger.info(f"Model parameters: {total_params:,}")
            logger.info(f"Model device: {next(model.parameters()).device}")
        return ok
    except Exception as e:
        logger.error(f"Model loading error: {e}")
        return False
def is_model_loaded() -> bool:
    """Return True only when the loaded flag, model, and tokenizer are all set."""
    return all((model_loaded, model is not None, tokenizer is not None))
def get_model_info() -> Dict:
    """Describe the loaded model/tokenizer plus rolling latency statistics."""
    if not is_model_loaded():
        return {"status": "not_loaded"}
    try:
        return {
            "status": "loaded",
            "model_type": type(model).__name__,
            "tokenizer_type": type(tokenizer).__name__,
            "device": str(device),
            "parameters": sum(p.numel() for p in model.parameters()),
            "avg_inference_time": np.mean(inference_times) if inference_times else 0.0,
            "total_predictions": len(inference_times),
            "drug_keywords_count": len(DRUG_KEYWORDS),
            "high_risk_keywords_count": len(HIGH_RISK_KEYWORDS)
        }
    except Exception as e:
        logger.error(f"Error getting model info: {e}")
        return {"status": "error", "error": str(e)}
# =======================
# Batch prediction for performance
# =======================
def predict_batch(texts: List[str], batch_size: int = 8) -> List[Tuple[int, float]]:
    """Run predict() over a list of texts, returning one (label, conf) per text.

    NOTE(review): despite the name, texts are processed one at a time by
    delegating to predict(); *batch_size* only chunks the loop and has no
    performance effect. True batched tokenization/inference would be the
    actual optimization here.
    """
    if not texts:
        return []
    if not is_model_loaded():
        logger.error("Model not loaded for batch prediction")
        return [(0, 0.0) for _ in texts]
    try:
        results: List[Tuple[int, float]] = []
        for start in range(0, len(texts), batch_size):
            for item in texts[start:start + batch_size]:
                results.append(predict(item))
        logger.info(f"Batch prediction completed for {len(texts)} texts")
        return results
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}")
        return [(0, 0.0) for _ in texts]
# =======================
# Performance monitoring
# =======================
def get_performance_stats() -> Dict:
    """Return timing statistics over the recent (at most 100) predictions."""
    if not inference_times:
        return {"status": "no_data"}
    timings = inference_times
    return {
        "total_predictions": len(timings),
        "avg_inference_time": np.mean(timings),
        "min_inference_time": min(timings),
        "max_inference_time": max(timings),
        "std_inference_time": np.std(timings),
        "device": str(device)
    }
# =======================
# Module initialization
# =======================
def initialize_model(model_path: str = None) -> bool:
    """Initialize the prediction module; returns False when no path is given."""
    return load_model(model_path) if model_path else False
# =======================
# Main execution for testing
# =======================
if __name__ == "__main__":
    # Smoke-test the prediction system on a few representative inputs.
    sample_texts = [
        "Hey, can you pick up some stuff from behind the metro station?",
        "I'm going to the grocery store to buy some milk and bread.",
        "The quality is really good this time, payment through crypto as usual.",
        "Let's meet for coffee tomorrow morning at 9 AM."
    ]
    print("Testing prediction system...")
    for idx, sample in enumerate(sample_texts, start=1):
        label, score = predict(sample)
        verdict = "DRUG" if label == 1 else "NON_DRUG"
        print(f"Text {idx}: {verdict} (confidence: {score:.3f})")
        print(f" Input: {sample}")
        print()
    # Print performance stats
    print("Performance Stats:", get_performance_stats())
    # Print model info
    print("Model Info:", get_model_info())