# NOTE: the lines below were Hugging Face Spaces page chrome ("Spaces: Sleeping")
# captured by the scrape — they are not part of the Python module.
| # predict.py - Optimized Production Version with Enhanced Accuracy & Performance | |
| from transformers import ( | |
| DistilBertTokenizerFast, | |
| DistilBertForSequenceClassification, | |
| DistilBertConfig, | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification | |
| ) | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import logging | |
| import os | |
| import json | |
| import shutil | |
| import re | |
| import time | |
| from typing import Tuple, List, Optional, Dict | |
| from functools import lru_cache | |
| import threading | |
# =======================
# Logging configuration
# =======================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =======================
# Global variables with thread safety
# =======================
model = None          # lazily populated by load_model_with_fallback()
tokenizer = None      # lazily populated alongside `model`
model_loaded = False  # set under model_lock once loading succeeds
model_lock = threading.Lock()
# Pick GPU when available; all tensors/models are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Performance tracking: per-call latencies (predict() trims this to the
# most recent 100 entries)
inference_times = []

# Keyword lists — declared empty here, populated at import time right after
# the loader functions are defined below.
DRUG_KEYWORDS = []
HIGH_RISK_KEYWORDS = []
# ========================
# Keyword list loading (safe fallbacks when files are absent)
# ========================
def load_keywords(file_path="slang_keywords.txt") -> List[str]:
    """Read one slang keyword per line from *file_path*.

    Blank lines and lines whose raw text starts with '#' are skipped;
    entries are lowercased. Returns an empty list when the file is
    missing or unreadable — never raises.
    """
    if not os.path.exists(file_path):
        logger.warning(f"Keyword file not found: {file_path}. Using default keywords.")
        return []
    try:
        with open(file_path, "r", encoding='utf-8') as handle:
            entries = []
            for raw in handle:
                # Filter on the raw line (matches original semantics: only
                # lines literally beginning with '#' are treated as comments).
                if raw.strip() and not raw.startswith('#'):
                    entries.append(raw.strip().lower())
        logger.info(f"Loaded {len(entries)} slang keywords from {file_path}")
        return entries
    except Exception as e:
        logger.error(f"Failed to load keywords from {file_path}: {e}")
        return []
def load_high_risk_keywords(file_path="high_risk_keywords.txt") -> List[str]:
    """Read high-risk keywords (one per line) from *file_path*.

    Blank lines and lines whose raw text starts with '#' are skipped;
    entries are lowercased. A built-in fallback list is returned when the
    file is missing or unreadable — never raises.
    """
    fallback = ["cocaine", "heroin", "mdma", "lsd", "meth", "fentanyl",
                "dealer", "supplier"]
    if not os.path.exists(file_path):
        logger.warning(f"High-risk keyword file not found: {file_path}")
        return fallback
    try:
        with open(file_path, "r", encoding='utf-8') as handle:
            entries = [raw.strip().lower() for raw in handle
                       if raw.strip() and not raw.startswith('#')]
        logger.info(f"Loaded {len(entries)} high-risk keywords from {file_path}")
        return entries
    except Exception as exc:
        logger.error(f"Failed to load high-risk keywords: {exc}")
        return fallback
# Initialize global keyword lists at import time. Both loaders return safe
# defaults when their files are missing, so module import never raises here.
DRUG_KEYWORDS = load_keywords("slang_keywords.txt")
HIGH_RISK_KEYWORDS = load_high_risk_keywords("high_risk_keywords.txt")
# =======================
# Enhanced text preprocessing for better accuracy
# =======================
# NOTE(review): this function is re-defined verbatim later in the file; the
# later binding is the one in effect once the module finishes importing, so
# this copy is dead code. Kept as-is pending cleanup of the duplication.
def preprocess_text(text: str) -> str:
    """Enhanced text preprocessing for better model accuracy.

    Lowercases, collapses whitespace runs, expands chat abbreviations via
    whole-word regex substitution, and squeezes repeated punctuation.
    Returns "" for falsy input.
    """
    if not text:
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove excessive whitespace but preserve sentence structure
    text = re.sub(r'\s+', ' ', text)
    # Handle common abbreviations and slang normalization
    abbreviations = {
        'u': 'you',
        'ur': 'your',
        'n': 'and',
        # NOTE(review): r'\bw/\b' cannot match before a space ('/' and ' '
        # are both non-word chars), so this entry never fires on "w/ milk".
        'w/': 'with',
        'thru': 'through',
        'gonna': 'going to',
        'wanna': 'want to',
        'gotta': 'got to'
    }
    for abbrev, full in abbreviations.items():
        text = re.sub(rf'\b{re.escape(abbrev)}\b', full, text)
    # Remove excessive punctuation but keep sentence boundaries
    text = re.sub(r'[!]{2,}', '!', text)
    text = re.sub(r'[?]{2,}', '?', text)
    text = re.sub(r'[.]{3,}', '...', text)
    return text.strip()
# =======================
# Enhanced keyword-based scoring
# =======================
def compute_keyword_score(text: str) -> Tuple[float, Dict[str, int]]:
    """Compute keyword-based score for enhanced accuracy"""
    # NOTE(review): truncated duplicate — the body stops after lowercasing
    # and implicitly returns None (violating the annotated return type).
    # It is harmless only because the complete compute_keyword_score defined
    # later in this file rebinds the name. Candidate for deletion.
    text_lower = text.lower()
# =======================
# Enhanced text preprocessing for better accuracy
# =======================
def preprocess_text(text: str) -> str:
    """Normalize raw text before tokenization / keyword matching.

    Lowercases, collapses whitespace runs, expands common chat
    abbreviations ("u" -> "you", "w/" -> "with", ...) via whole-word
    regex substitution, and squeezes repeated punctuation.

    Args:
        text: Raw input string (may be None/empty).
    Returns:
        The normalized string; "" for falsy input.
    """
    if not text:
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove excessive whitespace but preserve sentence structure
    text = re.sub(r'\s+', ' ', text)
    # Handle common abbreviations and slang normalization
    abbreviations = {
        'u': 'you',
        'ur': 'your',
        'n': 'and',
        'thru': 'through',
        'gonna': 'going to',
        'wanna': 'want to',
        'gotta': 'got to'
    }
    for abbrev, full in abbreviations.items():
        text = re.sub(rf'\b{re.escape(abbrev)}\b', full, text)
    # BUG FIX: "w/" cannot be expanded with a trailing \b — '/' is a
    # non-word char, so r'\bw/\b' never matched "w/ milk" (no boundary
    # before the space) and glued words together when it did match
    # ("w/milk" -> "withmilk"). Expand it separately, normalizing to
    # "with" followed by a single space.
    text = re.sub(r'\bw/\s*', 'with ', text)
    # Remove excessive punctuation but keep sentence boundaries
    text = re.sub(r'[!]{2,}', '!', text)
    text = re.sub(r'[?]{2,}', '?', text)
    text = re.sub(r'[.]{3,}', '...', text)
    return text.strip()
# =======================
# Enhanced keyword-based scoring
# =======================
def compute_keyword_score(text: str) -> Tuple[float, Dict[str, int]]:
    """Compute keyword-based score for enhanced accuracy"""
    # NOTE(review): another truncated duplicate — body ends after lowercasing
    # and implicitly returns None. Shadowed by the full compute_keyword_score
    # defined just below; candidate for deletion.
    text_lower = text.lower()
# Terms that are only suspicious when drug-context verbs appear nearby.
AMBIGUOUS_TERMS = {"e", "x", "line", "ice", "horse", "420"}

def keyword_check_with_context(text: str, kw: str) -> bool:
    """Whole-word match of *kw* in *text*, context-gated for ambiguous terms.

    Unambiguous keywords match on their own; keywords in AMBIGUOUS_TERMS
    additionally require one of the drug-context verbs (smoke, buy, snort,
    ...) somewhere in the same text.
    """
    if not re.search(rf"\b{re.escape(kw)}\b", text, re.IGNORECASE):
        return False
    if kw not in AMBIGUOUS_TERMS:
        return True
    context_verbs = r"\b(smoke|roll|pop|hit|take|buy|sell|party|snort|inject)\b"
    return bool(re.search(context_verbs, text, re.IGNORECASE))
def compute_keyword_score(text: str) -> Tuple[float, Dict[str, int]]:
    """Heuristic drug-likelihood score in [0, 1] plus per-category hit counts.

    Combines whole-word matches against the slang (DRUG_KEYWORDS) and
    high-risk (HIGH_RISK_KEYWORDS) lists — context-gated for ambiguous
    terms — with regex patterns typical of deal-making chatter. Each
    evidence channel contributes a capped amount; the total is capped at 1.
    """
    lowered = text.lower()

    slang_hits = sum(
        1 for kw in DRUG_KEYWORDS if keyword_check_with_context(lowered, kw))
    risky_hits = sum(
        1 for kw in HIGH_RISK_KEYWORDS if keyword_check_with_context(lowered, kw))

    deal_patterns = (
        r'(?i)(pick.*up|got.*stuff|meet.*behind)',
        r'(?i)(payment|crypto|cash.*deal)',
        r'(?i)(supplier|dealer|connect)',
        r'(?i)(party.*saturday|rave.*tonight)',
        r'(?i)(quality.*good|pure.*stuff)',
        r'(?i)(cops.*around|too.*risky)',
    )
    pattern_hits = sum(1 for pat in deal_patterns if re.search(pat, lowered))

    # Channel caps: high-risk dominates (0.8), slang and context are softer.
    score = 0.0
    if risky_hits > 0:
        score += min(risky_hits * 0.3, 0.8)
    if slang_hits > 0:
        score += min(slang_hits * 0.1, 0.3)
    if pattern_hits > 0:
        score += min(pattern_hits * 0.15, 0.4)
    score = min(score, 1.0)

    return score, {
        'drug_keywords': slang_hits,
        'high_risk_keywords': risky_hits,
        'context_patterns': pattern_hits,
    }
# =======================
# Config validation/fix with enhanced error handling
# =======================
def validate_and_fix_config(model_path: str) -> bool:
    """Sanity-check config.json under *model_path*, repairing bad dimensions.

    When `dim` is not divisible by `n_heads`, the config is overwritten with
    stock DistilBERT-base values after saving a one-time backup
    (config.json.backup). Returns True when the config exists and was
    validated (or fixed); False when it is missing or any step fails.
    """
    cfg_file = os.path.join(model_path, "config.json")
    if not os.path.exists(cfg_file):
        logger.warning(f"Config file not found at {cfg_file}")
        return False
    try:
        with open(cfg_file, 'r', encoding='utf-8') as fh:
            cfg = json.load(fh)

        hidden = cfg.get('dim', 768)
        heads = cfg.get('n_heads', 12)
        if hidden % heads != 0:
            logger.warning(f"Configuration issue detected: dim={hidden} not divisible by n_heads={heads}")
            # Keep the first-ever backup; never clobber an existing one.
            backup = cfg_file + ".backup"
            if not os.path.exists(backup):
                shutil.copy2(cfg_file, backup)
                logger.info(f"Backed up original config to {backup}")
            # Overwrite the suspect fields with stock distilbert-base values.
            cfg.update({
                'dim': 768,
                'n_heads': 12,
                'hidden_dim': 3072,
                'n_layers': 6,
                'vocab_size': 30522,
                'max_position_embeddings': 512,
                'dropout': 0.1,
                'attention_dropout': 0.1,
                'activation': 'gelu',
                'num_labels': 2,
            })
            with open(cfg_file, 'w', encoding='utf-8') as fh:
                json.dump(cfg, fh, indent=2)
            logger.info("Fixed configuration with standard DistilBERT dimensions")
        logger.info("Configuration validation completed")
        return True
    except Exception as e:
        logger.error(f"Error validating/fixing config: {e}")
        return False
# =======================
# Enhanced model loading with multiple fallback strategies
# =======================
def load_model_with_fallback(model_name: str) -> bool:
    """Load the stock DistilBERT classifier into the module globals.

    NOTE(review): the custom checkpoint is deliberately bypassed because of
    tokenizer issues — *model_name* is accepted but currently unused.
    Loading is serialized via model_lock and idempotent: a second call
    returns True immediately once the model is in place.
    """
    global model, tokenizer, model_loaded
    with model_lock:
        if model_loaded:
            return True
        logger.info("Using standard DistilBERT model (custom model has tokenizer issues)")
        try:
            tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
            model = AutoModelForSequenceClassification.from_pretrained(
                'distilbert-base-uncased',
                num_labels=2,
            )
            model.to(device)
            model.eval()  # inference mode: disables dropout
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return False
        model_loaded = True
        logger.info("Standard model loaded successfully")
        return True
# =======================
# Optimized prediction function with enhanced accuracy
# =======================
def predict(text: str, return_confidence: bool = True) -> Tuple[int, float]:
    """
    Enhanced prediction with improved accuracy and performance.

    Classifies *text* as drug-related (1) or not (0) by blending the
    DistilBERT softmax probability with keyword/context heuristics via
    combine_predictions(). Also records per-call latency in the module-level
    inference_times list (trimmed to the last 100 entries).

    Args:
        text: Input text to classify
        return_confidence: Whether to return confidence scores.
            NOTE(review): currently unused — a (label, confidence) tuple is
            always returned regardless of this flag.
    Returns:
        Tuple of (prediction_label, confidence_score). (0, 0.0) is returned
        for empty input, an unloaded model, or any tokenization/inference
        failure — callers cannot distinguish "confident non-drug" from
        "error" by the tuple alone.
    """
    start_time = time.time()
    try:
        # Input validation
        if not text or not text.strip():
            logger.warning("Empty input text")
            return 0, 0.0
        # Check if model is loaded
        if not model_loaded or model is None or tokenizer is None:
            logger.error("Model not loaded properly")
            return 0, 0.0
        # Preprocess text for better accuracy
        processed_text = preprocess_text(text)
        # Get keyword-based score
        keyword_score, keyword_stats = compute_keyword_score(processed_text)
        # Tokenize with proper handling
        try:
            inputs = tokenizer(
                processed_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
                add_special_tokens=True
            )
            # Move every input tensor onto the same device as the model.
            inputs = {k: v.to(device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"Tokenization failed: {e}")
            return 0, 0.0
        # Model inference with no_grad for performance
        with torch.no_grad():
            try:
                outputs = model(**inputs)
                logits = outputs.logits
                # Apply softmax to get probabilities
                probabilities = F.softmax(logits, dim=-1)
                # NOTE(review): assumes label index 1 is the "drug" class —
                # confirm against the training label mapping.
                ml_confidence = probabilities[0][1].item()  # Probability of drug class
                ml_prediction = int(ml_confidence > 0.5)
            except Exception as e:
                logger.error(f"Model inference failed: {e}")
                return 0, 0.0
        # Enhanced decision making combining ML and keywords
        final_prediction, final_confidence = combine_predictions(
            ml_prediction, ml_confidence, keyword_score, keyword_stats, processed_text
        )
        # Log performance metrics
        inference_time = time.time() - start_time
        inference_times.append(inference_time)
        # Keep only last 100 timing records
        if len(inference_times) > 100:
            inference_times.pop(0)
        logger.info(f"Prediction completed in {inference_time:.3f}s - "
                    f"Result: {'DRUG' if final_prediction == 1 else 'NON_DRUG'} "
                    f"(confidence: {final_confidence:.3f}, keyword_score: {keyword_score:.3f})")
        return final_prediction, final_confidence
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        return 0, 0.0
# =======================
# Enhanced prediction combination logic
# =======================
def combine_predictions(ml_pred: int, ml_conf: float, keyword_score: float,
                        keyword_stats: Dict[str, int], text: str) -> Tuple[int, float]:
    """
    Blend the model probability with the keyword heuristic score.

    The stronger the keyword evidence, the more weight the keyword channel
    receives; the final label is then decided by a cascade of thresholds
    (strongest evidence first). On any internal error the raw ML prediction
    is returned unchanged.
    """
    try:
        n_high = keyword_stats.get('high_risk_keywords', 0)
        n_drug = keyword_stats.get('drug_keywords', 0)
        n_ctx = keyword_stats.get('context_patterns', 0)

        # Evidence-dependent channel weights: (ml_weight, keyword_weight).
        if n_high >= 2:
            w_ml, w_kw = 0.2, 0.8
        elif n_high >= 1 or n_drug >= 3:
            w_ml, w_kw = 0.3, 0.7
        elif n_drug >= 2 or n_ctx >= 2:
            w_ml, w_kw = 0.4, 0.6
        else:
            w_ml, w_kw = 0.7, 0.3
        blended = (w_ml * ml_conf) + (w_kw * keyword_score)

        # Threshold cascade, strongest evidence first.
        if n_high >= 1:
            # Any high-risk keyword forces a positive with a confidence floor.
            label, confidence = 1, max(blended, 0.7)
        elif keyword_score >= 0.5:
            label, confidence = 1, blended
        elif keyword_score >= 0.3 and ml_conf >= 0.3:
            label, confidence = 1, blended
        elif ml_conf >= 0.7:
            label, confidence = 1, blended
        else:
            # Weak evidence everywhere: negative, with a small floor.
            label, confidence = 0, max(blended, 0.1)

        # Clamp confidence into [0, 1].
        return label, max(0.0, min(1.0, confidence))
    except Exception as e:
        logger.error(f"Prediction combination failed: {e}")
        return ml_pred, ml_conf
# =======================
# Model management functions
# =======================
def load_model(model_path: str) -> bool:
    """Load the classifier via load_model_with_fallback() and log its stats.

    Returns True on success, False otherwise; never raises.
    """
    try:
        ok = load_model_with_fallback(model_path)
        if not ok:
            logger.error(f"Failed to load model from {model_path}")
            return ok
        logger.info(f"Model loaded successfully from {model_path}")
        # Log model info
        if model:
            total_params = sum(p.numel() for p in model.parameters())
            logger.info(f"Model parameters: {total_params:,}")
            logger.info(f"Model device: {next(model.parameters()).device}")
        return ok
    except Exception as e:
        logger.error(f"Model loading error: {e}")
        return False
def is_model_loaded() -> bool:
    """True once the model and tokenizer are both initialized."""
    return all((model_loaded, model is not None, tokenizer is not None))
def get_model_info() -> Dict:
    """Snapshot of model status, size, device, and runtime statistics.

    Returns {"status": "not_loaded"} before initialization and
    {"status": "error", ...} when introspection fails.
    """
    if not is_model_loaded():
        return {"status": "not_loaded"}
    try:
        total_params = sum(p.numel() for p in model.parameters())
        mean_latency = np.mean(inference_times) if inference_times else 0.0
        return {
            "status": "loaded",
            "model_type": type(model).__name__,
            "tokenizer_type": type(tokenizer).__name__,
            "device": str(device),
            "parameters": total_params,
            "avg_inference_time": mean_latency,
            "total_predictions": len(inference_times),
            "drug_keywords_count": len(DRUG_KEYWORDS),
            "high_risk_keywords_count": len(HIGH_RISK_KEYWORDS),
        }
    except Exception as e:
        logger.error(f"Error getting model info: {e}")
        return {"status": "error", "error": str(e)}
# =======================
# Batch prediction for performance
# =======================
def predict_batch(texts: List[str], batch_size: int = 8) -> List[Tuple[int, float]]:
    """Classify many texts, one (label, confidence) pair per input.

    NOTE(review): despite the name, each text is still scored with an
    individual predict() call — batch_size only chunks the loop, it does
    not batch tokenizer/model work.
    Returns (0, 0.0) for every input when the model is not loaded or on
    any unexpected failure.
    """
    if not texts:
        return []
    if not is_model_loaded():
        logger.error("Model not loaded for batch prediction")
        return [(0, 0.0) for _ in texts]
    try:
        results: List[Tuple[int, float]] = []
        for start in range(0, len(texts), batch_size):
            for item in texts[start:start + batch_size]:
                results.append(predict(item))
        logger.info(f"Batch prediction completed for {len(texts)} texts")
        return results
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}")
        return [(0, 0.0) for _ in texts]
# =======================
# Performance monitoring
# =======================
def get_performance_stats() -> Dict:
    """Latency statistics over the recorded inference times.

    Returns {"status": "no_data"} until at least one prediction has run.
    """
    timings = inference_times
    if not timings:
        return {"status": "no_data"}
    return {
        "total_predictions": len(timings),
        "avg_inference_time": np.mean(timings),
        "min_inference_time": min(timings),
        "max_inference_time": max(timings),
        "std_inference_time": np.std(timings),
        "device": str(device),
    }
# =======================
# Module initialization
# =======================
def initialize_model(model_path: Optional[str] = None) -> bool:
    """Initialize the prediction module.

    Args:
        model_path: checkpoint path/name forwarded to load_model(). The
            annotation is now an explicit Optional (PEP 484 deprecates the
            implicit `str = None` form); behavior is unchanged.
    Returns:
        True when a model was loaded successfully; False when *model_path*
        is falsy or loading fails.
    """
    if not model_path:
        return False
    return load_model(model_path)
# =======================
# Main execution for testing
# =======================
if __name__ == "__main__":
    # BUG FIX: the model was never loaded before calling predict(), so every
    # demo text returned (0, 0.0) with a "Model not loaded properly" error.
    # Load it first; load_model_with_fallback() currently ignores the path
    # and pulls the stock distilbert-base-uncased checkpoint.
    if not initialize_model("distilbert-base-uncased"):
        print("WARNING: model failed to load; predictions will be (0, 0.0).")

    # Test the prediction system
    test_texts = [
        "Hey, can you pick up some stuff from behind the metro station?",
        "I'm going to the grocery store to buy some milk and bread.",
        "The quality is really good this time, payment through crypto as usual.",
        "Let's meet for coffee tomorrow morning at 9 AM."
    ]
    print("Testing prediction system...")
    for i, text in enumerate(test_texts):
        pred, conf = predict(text)
        result = "DRUG" if pred == 1 else "NON_DRUG"
        print(f"Text {i+1}: {result} (confidence: {conf:.3f})")
        print(f"  Input: {text}")
        print()
    # Print performance stats
    stats = get_performance_stats()
    print("Performance Stats:", stats)
    # Print model info
    info = get_model_info()
    print("Model Info:", info)