Spaces:
Sleeping
Sleeping
Update src/predict.py
Browse files- src/predict.py +86 -172
src/predict.py
CHANGED
|
@@ -1,55 +1,53 @@
|
|
| 1 |
-
# predict.py -
|
| 2 |
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
import json
|
|
|
|
| 8 |
|
| 9 |
-
#
|
|
|
|
|
|
|
| 10 |
logging.basicConfig(level=logging.INFO)
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
-
#
|
|
|
|
|
|
|
| 14 |
model = None
|
| 15 |
tokenizer = None
|
| 16 |
model_loaded = False
|
| 17 |
|
|
|
|
|
|
|
|
|
|
| 18 |
def validate_and_fix_config(model_path):
|
| 19 |
"""Validate and fix model configuration if needed"""
|
| 20 |
config_path = os.path.join(model_path, "config.json")
|
| 21 |
-
|
| 22 |
if not os.path.exists(config_path):
|
| 23 |
logger.warning(f"Config file not found at {config_path}")
|
| 24 |
return False
|
| 25 |
-
|
| 26 |
try:
|
| 27 |
with open(config_path, 'r') as f:
|
| 28 |
config_data = json.load(f)
|
| 29 |
|
| 30 |
-
# Check for the problematic configuration
|
| 31 |
dim = config_data.get('dim', 768)
|
| 32 |
n_heads = config_data.get('n_heads', 12)
|
| 33 |
-
|
| 34 |
if dim % n_heads != 0:
|
| 35 |
logger.warning(f"Configuration issue detected: dim={dim} not divisible by n_heads={n_heads}")
|
| 36 |
-
|
| 37 |
-
# Backup original config
|
| 38 |
backup_path = config_path + ".backup"
|
| 39 |
if not os.path.exists(backup_path):
|
| 40 |
-
import shutil
|
| 41 |
shutil.copy2(config_path, backup_path)
|
| 42 |
logger.info(f"Backed up original config to {backup_path}")
|
| 43 |
|
| 44 |
-
# Fix configuration
|
| 45 |
config_data['dim'] = 768
|
| 46 |
config_data['n_heads'] = 12
|
| 47 |
config_data['hidden_dim'] = 3072
|
| 48 |
-
|
| 49 |
-
# Write fixed config
|
| 50 |
with open(config_path, 'w') as f:
|
| 51 |
json.dump(config_data, f, indent=2)
|
| 52 |
-
|
| 53 |
logger.info("Fixed configuration with standard DistilBERT dimensions")
|
| 54 |
return True
|
| 55 |
|
|
@@ -60,37 +58,30 @@ def validate_and_fix_config(model_path):
|
|
| 60 |
logger.error(f"Error validating/fixing config: {e}")
|
| 61 |
return False
|
| 62 |
|
|
|
|
|
|
|
|
|
|
| 63 |
def load_model_with_fallback(model_name):
|
| 64 |
-
"""Load model with fallback strategies"""
|
| 65 |
global model, tokenizer
|
| 66 |
-
|
| 67 |
-
# Strategy 1: Try loading with config validation
|
| 68 |
if os.path.exists(model_name):
|
| 69 |
logger.info(f"Attempting to load local model from {model_name}")
|
| 70 |
-
|
| 71 |
-
# Validate and fix config first
|
| 72 |
-
config_valid = validate_and_fix_config(model_name)
|
| 73 |
-
if not config_valid:
|
| 74 |
-
logger.warning("Could not validate/fix config, trying anyway...")
|
| 75 |
-
|
| 76 |
try:
|
| 77 |
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
|
| 78 |
model = DistilBertForSequenceClassification.from_pretrained(
|
| 79 |
model_name,
|
| 80 |
-
ignore_mismatched_sizes=True
|
| 81 |
)
|
| 82 |
logger.info("Successfully loaded local model")
|
| 83 |
return True
|
| 84 |
-
|
| 85 |
except Exception as e:
|
| 86 |
logger.error(f"Failed to load local model: {e}")
|
| 87 |
-
|
| 88 |
-
# Strategy 2:
|
| 89 |
if os.path.exists(model_name):
|
| 90 |
try:
|
| 91 |
logger.info("Attempting to load with custom configuration...")
|
| 92 |
-
|
| 93 |
-
# Create a working config
|
| 94 |
config = DistilBertConfig(
|
| 95 |
vocab_size=30522,
|
| 96 |
max_position_embeddings=512,
|
|
@@ -106,14 +97,8 @@ def load_model_with_fallback(model_name):
|
|
| 106 |
seq_classif_dropout=0.2,
|
| 107 |
num_labels=2
|
| 108 |
)
|
| 109 |
-
|
| 110 |
-
# Load tokenizer
|
| 111 |
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
|
| 112 |
-
|
| 113 |
-
# Create model with fixed config
|
| 114 |
model = DistilBertForSequenceClassification(config)
|
| 115 |
-
|
| 116 |
-
# Try to load existing weights
|
| 117 |
weights_path = os.path.join(model_name, "pytorch_model.bin")
|
| 118 |
if os.path.exists(weights_path):
|
| 119 |
try:
|
|
@@ -122,221 +107,150 @@ def load_model_with_fallback(model_name):
|
|
| 122 |
logger.info("Loaded existing weights with custom config")
|
| 123 |
except Exception as weight_error:
|
| 124 |
logger.warning(f"Could not load weights: {weight_error}")
|
| 125 |
-
logger.info("Using randomly initialized model")
|
| 126 |
-
|
| 127 |
return True
|
| 128 |
-
|
| 129 |
except Exception as e:
|
| 130 |
logger.error(f"Custom config loading failed: {e}")
|
| 131 |
|
| 132 |
-
# Strategy 3:
|
| 133 |
try:
|
| 134 |
logger.info("Loading fallback model from HuggingFace...")
|
| 135 |
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
|
| 136 |
model = DistilBertForSequenceClassification.from_pretrained(
|
| 137 |
-
'distilbert-base-uncased',
|
| 138 |
-
num_labels=2
|
| 139 |
)
|
| 140 |
-
logger.warning("Using pre-trained DistilBERT as fallback -
|
| 141 |
return True
|
| 142 |
-
|
| 143 |
except Exception as e:
|
| 144 |
logger.error(f"Fallback model loading failed: {e}")
|
| 145 |
return False
|
| 146 |
|
|
|
|
|
|
|
|
|
|
| 147 |
def load_model(model_name="drug_classifier_model"):
|
| 148 |
-
"""Load model and tokenizer once with enhanced error handling"""
|
| 149 |
global model, tokenizer, model_loaded
|
| 150 |
-
|
| 151 |
if model_loaded:
|
| 152 |
-
return
|
| 153 |
-
|
| 154 |
try:
|
| 155 |
-
# Use the enhanced loading function
|
| 156 |
success = load_model_with_fallback(model_name)
|
| 157 |
-
|
| 158 |
if not success:
|
| 159 |
raise RuntimeError("All model loading strategies failed")
|
| 160 |
-
|
| 161 |
-
model.eval() # Set to evaluation mode
|
| 162 |
-
|
| 163 |
-
# Move to GPU if available
|
| 164 |
if torch.cuda.is_available():
|
| 165 |
-
model
|
| 166 |
logger.info("Model moved to GPU")
|
| 167 |
-
|
| 168 |
model_loaded = True
|
| 169 |
-
logger.info(
|
| 170 |
-
|
| 171 |
-
# Log model configuration if available
|
| 172 |
-
if hasattr(model.config, 'id2label'):
|
| 173 |
-
logger.info(f"Model labels: {model.config.id2label}")
|
| 174 |
-
|
| 175 |
except Exception as e:
|
| 176 |
logger.error(f"Failed to load model or tokenizer: {e}")
|
| 177 |
raise
|
| 178 |
|
|
|
|
|
|
|
|
|
|
| 179 |
def predict(text, confidence_threshold=0.5):
|
| 180 |
-
"""
|
| 181 |
-
Predict whether the input text is DRUG (1) or NON_DRUG (0).
|
| 182 |
-
|
| 183 |
-
Args:
|
| 184 |
-
text (str): Input text to classify
|
| 185 |
-
confidence_threshold (float): Threshold for DRUG classification (default: 0.5)
|
| 186 |
-
|
| 187 |
-
Returns:
|
| 188 |
-
tuple: (label, drug_probability)
|
| 189 |
-
- label: 1 for DRUG, 0 for NON_DRUG
|
| 190 |
-
- drug_probability: float between 0 and 1
|
| 191 |
-
"""
|
| 192 |
-
# Ensure model is loaded
|
| 193 |
if not model_loaded:
|
| 194 |
load_model()
|
| 195 |
-
|
| 196 |
-
# Input validation
|
| 197 |
if not text or not isinstance(text, str):
|
| 198 |
-
logger.warning("Empty or invalid input text provided to predict")
|
| 199 |
return 0, 0.0
|
| 200 |
-
|
| 201 |
text = text.strip()
|
| 202 |
-
if
|
| 203 |
-
logger.warning("Empty text after stripping provided to predict")
|
| 204 |
return 0, 0.0
|
| 205 |
-
|
| 206 |
try:
|
| 207 |
-
# Tokenize input - use same max_length as training
|
| 208 |
inputs = tokenizer(
|
| 209 |
-
text,
|
| 210 |
-
return_tensors="pt",
|
| 211 |
-
truncation=True,
|
| 212 |
-
padding=True,
|
| 213 |
-
max_length=256 # Match training script
|
| 214 |
)
|
| 215 |
-
|
| 216 |
-
# Move inputs to same device as model
|
| 217 |
if torch.cuda.is_available() and next(model.parameters()).is_cuda:
|
| 218 |
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 219 |
-
|
| 220 |
-
# Make prediction
|
| 221 |
with torch.no_grad():
|
| 222 |
outputs = model(**inputs)
|
| 223 |
probs = F.softmax(outputs.logits, dim=-1)
|
| 224 |
-
|
| 225 |
-
non_drug_prob = probs[0][0].item() # Probability for NON_DRUG (index 0)
|
| 226 |
-
drug_prob = probs[0][1].item() # Probability for DRUG (index 1)
|
| 227 |
-
|
| 228 |
-
# Apply threshold for classification
|
| 229 |
pred_label = 1 if drug_prob > confidence_threshold else 0
|
| 230 |
-
|
| 231 |
-
# Log prediction details
|
| 232 |
-
logger.info(f"Prediction for: '{text[:100]}{'...' if len(text) > 100 else ''}'")
|
| 233 |
-
logger.info(f" Result: {'DRUG' if pred_label == 1 else 'NON_DRUG'}")
|
| 234 |
-
logger.info(f" DRUG probability: {drug_prob:.4f} ({drug_prob*100:.2f}%)")
|
| 235 |
-
logger.info(f" NON_DRUG probability: {non_drug_prob:.4f} ({non_drug_prob*100:.2f}%)")
|
| 236 |
-
logger.info(f" Confidence: {max(drug_prob, non_drug_prob):.4f}")
|
| 237 |
-
|
| 238 |
return pred_label, drug_prob
|
| 239 |
-
|
| 240 |
except Exception as e:
|
| 241 |
-
logger.error(f"
|
| 242 |
-
logger.error(f"Input text: {text}")
|
| 243 |
return 0, 0.0
|
| 244 |
|
|
|
|
|
|
|
|
|
|
| 245 |
def predict_batch(texts, confidence_threshold=0.5):
|
| 246 |
-
"""
|
| 247 |
-
Predict for multiple texts at once (more efficient)
|
| 248 |
-
|
| 249 |
-
Args:
|
| 250 |
-
texts (list): List of texts to classify
|
| 251 |
-
confidence_threshold (float): Threshold for DRUG classification
|
| 252 |
-
|
| 253 |
-
Returns:
|
| 254 |
-
list: List of (label, drug_probability) tuples
|
| 255 |
-
"""
|
| 256 |
if not model_loaded:
|
| 257 |
load_model()
|
| 258 |
-
|
| 259 |
if not texts or not isinstance(texts, list):
|
| 260 |
-
logger.warning("Empty or invalid text list provided to predict_batch")
|
| 261 |
return []
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
for i, text in enumerate(texts):
|
| 267 |
-
if text and isinstance(text, str) and text.strip():
|
| 268 |
-
valid_texts.append(text.strip())
|
| 269 |
valid_indices.append(i)
|
| 270 |
-
|
| 271 |
if not valid_texts:
|
| 272 |
-
|
| 273 |
-
return [(0, 0.0)] * len(texts)
|
| 274 |
-
|
| 275 |
try:
|
| 276 |
-
|
| 277 |
-
inputs = tokenizer(
|
| 278 |
-
valid_texts,
|
| 279 |
-
return_tensors="pt",
|
| 280 |
-
truncation=True,
|
| 281 |
-
padding=True,
|
| 282 |
-
max_length=256
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
# Move to device
|
| 286 |
if torch.cuda.is_available() and next(model.parameters()).is_cuda:
|
| 287 |
-
inputs = {k: v.cuda() for k,
|
| 288 |
-
|
| 289 |
-
# Make predictions
|
| 290 |
with torch.no_grad():
|
| 291 |
outputs = model(**inputs)
|
| 292 |
probs = F.softmax(outputs.logits, dim=-1)
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
results = [(0, 0.0)] * len(texts)
|
| 296 |
-
for i, valid_idx in enumerate(valid_indices):
|
| 297 |
drug_prob = probs[i][1].item()
|
| 298 |
pred_label = 1 if drug_prob > confidence_threshold else 0
|
| 299 |
-
results[
|
| 300 |
-
|
| 301 |
-
logger.info(f"Batch prediction completed for {len(valid_texts)} texts")
|
| 302 |
return results
|
| 303 |
-
|
| 304 |
except Exception as e:
|
| 305 |
-
logger.error(f"
|
| 306 |
-
return [(0,
|
| 307 |
|
|
|
|
|
|
|
|
|
|
| 308 |
def test_predictions():
|
| 309 |
-
"""Test function to verify model predictions"""
|
| 310 |
logger.info("Running prediction tests...")
|
| 311 |
-
|
| 312 |
test_cases = [
|
| 313 |
-
# Should be DRUG
|
| 314 |
"Bro, check the Insta DM. That the white or the blue? White, straight from Mumbai. Cool, payment through crypto, right? Who's bringing the stuff? Raj, Tabs, Weed and Coke. Let's not overdose this time.",
|
| 315 |
"Got some quality hash and charas ready for pickup tonight",
|
| 316 |
"MDMA tabs are available, payment through crypto only",
|
| 317 |
-
|
| 318 |
-
# Should be NON_DRUG
|
| 319 |
"Hey, how's your work going today? Let's meet for coffee this evening.",
|
| 320 |
"The weather is really nice today, perfect for a walk in the park",
|
| 321 |
"I need to finish my project by tomorrow, can you help me?",
|
| 322 |
]
|
| 323 |
-
|
| 324 |
-
for i, text in enumerate(test_cases, 1):
|
| 325 |
-
logger.info(f"\n--- Test Case {i} ---")
|
| 326 |
label, prob = predict(text)
|
| 327 |
-
expected = "DRUG" if i
|
| 328 |
-
actual = "DRUG" if label
|
| 329 |
-
logger.info(f"
|
| 330 |
-
|
| 331 |
-
# ===========================
|
| 332 |
-
# HF Spaces / Production Ready
|
| 333 |
-
# ===========================
|
| 334 |
|
|
|
|
|
|
|
|
|
|
| 335 |
if __name__ == "__main__":
|
| 336 |
-
# Load model once
|
| 337 |
load_model()
|
| 338 |
-
|
| 339 |
-
# Optional: run test predictions to verify setup
|
| 340 |
test_predictions()
|
| 341 |
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# predict.py - Production + Interactive Compatible Version
|
| 2 |
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
import json
|
| 8 |
+
import shutil
|
| 9 |
|
| 10 |
+
# =======================
# Logging configuration
# =======================
# Module-level logger per stdlib convention; INFO so load/predict steps are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =======================
# Global variables
# =======================
# Lazily populated by load_model(); kept at module level so repeated
# predict()/predict_batch() calls reuse the same model and tokenizer.
model = None  # set by load_model_with_fallback(); a DistilBERT sequence classifier
tokenizer = None  # set by load_model_with_fallback(); the matching fast tokenizer
model_loaded = False  # guard flag so load_model() only does the work once
|
| 22 |
|
| 23 |
+
# =======================
# Config validation/fix
# =======================
def validate_and_fix_config(model_path):
    """Validate and fix model configuration if needed.

    Checks that the DistilBERT config at ``model_path/config.json`` has a
    hidden size (``dim``) divisible by the number of attention heads
    (``n_heads``). If not, the original config is backed up once and then
    rewritten with standard DistilBERT dimensions.

    Args:
        model_path (str): Directory expected to contain ``config.json``.

    Returns:
        bool: True if the config is valid or was successfully fixed,
        False if the file is missing or any error occurs.
    """
    config_path = os.path.join(model_path, "config.json")

    if not os.path.exists(config_path):
        logger.warning(f"Config file not found at {config_path}")
        return False

    try:
        with open(config_path, 'r') as f:
            config_data = json.load(f)

        # Defaults mirror standard DistilBERT so a sparse config still validates.
        dim = config_data.get('dim', 768)
        n_heads = config_data.get('n_heads', 12)

        if dim % n_heads != 0:
            logger.warning(f"Configuration issue detected: dim={dim} not divisible by n_heads={n_heads}")

            # Back up the original config once, before the first overwrite.
            backup_path = config_path + ".backup"
            if not os.path.exists(backup_path):
                shutil.copy2(config_path, backup_path)
                logger.info(f"Backed up original config to {backup_path}")

            # Fix configuration with standard DistilBERT dimensions.
            config_data['dim'] = 768
            config_data['n_heads'] = 12
            config_data['hidden_dim'] = 3072

            with open(config_path, 'w') as f:
                json.dump(config_data, f, indent=2)

            logger.info("Fixed configuration with standard DistilBERT dimensions")
            return True

        # BUG FIX: the original fell off the end here and implicitly returned
        # None (falsy) for an already-valid config; report success explicitly.
        return True

    except Exception as e:
        logger.error(f"Error validating/fixing config: {e}")
        return False
|
| 60 |
|
| 61 |
+
# =======================
# Model loading with fallback
# =======================
def load_model_with_fallback(model_name):
    """Load model and tokenizer into the module globals, trying three strategies.

    1. Load the fine-tuned local checkpoint directly (after config validation).
    2. Rebuild a standard DistilBERT config and load local weights into it.
    3. Fall back to plain pre-trained ``distilbert-base-uncased`` from the Hub.

    Args:
        model_name (str): Local checkpoint directory (or Hub id for strategy 3).

    Returns:
        bool: True once any strategy succeeds, False if all three fail.
    """
    global model, tokenizer

    # Strategy 1: load the fine-tuned local checkpoint directly.
    if os.path.exists(model_name):
        logger.info(f"Attempting to load local model from {model_name}")
        # Best-effort: proceed even if the config could not be validated/fixed.
        validate_and_fix_config(model_name)
        try:
            tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
            model = DistilBertForSequenceClassification.from_pretrained(
                model_name,
                ignore_mismatched_sizes=True
            )
            logger.info("Successfully loaded local model")
            return True
        except Exception as e:
            logger.error(f"Failed to load local model: {e}")

    # Strategy 2: build a known-good config, then pull in local weights.
    if os.path.exists(model_name):
        try:
            logger.info("Attempting to load with custom configuration...")
            config = DistilBertConfig(
                vocab_size=30522,
                max_position_embeddings=512,
                # NOTE(review): the middle of this kwargs list was lost in the
                # diff rendering; these values are the standard DistilBERT
                # architecture and match what validate_and_fix_config() writes
                # (dim=768, n_heads=12, hidden_dim=3072) — confirm against the
                # original file.
                dim=768,
                n_heads=12,
                n_layers=6,
                hidden_dim=3072,
                dropout=0.1,
                attention_dropout=0.1,
                seq_classif_dropout=0.2,
                num_labels=2
            )
            tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
            model = DistilBertForSequenceClassification(config)

            # Try to load existing fine-tuned weights into the rebuilt model.
            weights_path = os.path.join(model_name, "pytorch_model.bin")
            if os.path.exists(weights_path):
                try:
                    # NOTE(review): reconstructed — the original load lines were
                    # lost in the diff rendering. map_location='cpu' keeps this
                    # safe on CPU-only hosts; strict=False tolerates head-size
                    # mismatches, consistent with strategy 1's
                    # ignore_mismatched_sizes=True.
                    state_dict = torch.load(weights_path, map_location='cpu')
                    model.load_state_dict(state_dict, strict=False)
                    logger.info("Loaded existing weights with custom config")
                except Exception as weight_error:
                    logger.warning(f"Could not load weights: {weight_error}")
                    logger.info("Using randomly initialized model")
            return True
        except Exception as e:
            logger.error(f"Custom config loading failed: {e}")

    # Strategy 3: plain pre-trained DistilBERT from the HuggingFace Hub.
    try:
        logger.info("Loading fallback model from HuggingFace...")
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased', num_labels=2
        )
        logger.warning("Using pre-trained DistilBERT as fallback - retraining recommended")
        return True
    except Exception as e:
        logger.error(f"Fallback model loading failed: {e}")
        return False
|
| 126 |
|
| 127 |
+
# =======================
# Load model globally
# =======================
def load_model(model_name="drug_classifier_model"):
    """Load the classifier and tokenizer exactly once.

    Delegates to load_model_with_fallback(), switches the model to eval
    mode, and moves it to the GPU when one is available. Subsequent calls
    are no-ops thanks to the module-level ``model_loaded`` flag.

    Raises:
        RuntimeError: if every loading strategy fails.
    """
    global model, tokenizer, model_loaded

    if model_loaded:
        return

    try:
        if not load_model_with_fallback(model_name):
            raise RuntimeError("All model loading strategies failed")

        # Inference-only usage: freeze dropout/batch-norm behavior.
        model.eval()

        if torch.cuda.is_available():
            model.cuda()
            logger.info("Model moved to GPU")

        model_loaded = True
        logger.info("Successfully loaded model and tokenizer")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        raise
|
| 147 |
|
| 148 |
+
# =======================
# Single prediction
# =======================
def predict(text, confidence_threshold=0.5):
    """Classify a single text as DRUG (1) or NON_DRUG (0).

    Args:
        text (str): Input text to classify.
        confidence_threshold (float): DRUG probability above which the
            label is 1. Defaults to 0.5.

    Returns:
        tuple[int, float]: (label, drug_probability). Invalid/blank input
        or any runtime error yields (0, 0.0).
    """
    if not model_loaded:
        load_model()

    # Guard clauses: non-string, empty, or whitespace-only input.
    if not text or not isinstance(text, str):
        return 0, 0.0
    text = text.strip()
    if not text:
        return 0, 0.0

    try:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,  # must match the training-time max length
        )

        # Keep inputs on the same device as the model parameters.
        on_gpu = torch.cuda.is_available() and next(model.parameters()).is_cuda
        if on_gpu:
            encoded = {name: tensor.cuda() for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits
            probabilities = F.softmax(logits, dim=-1)

        drug_prob = probabilities[0][1].item()  # index 1 = DRUG class
        label = 0
        if drug_prob > confidence_threshold:
            label = 1
        return label, drug_prob
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        return 0, 0.0
|
| 174 |
|
| 175 |
+
# =======================
# Batch prediction
# =======================
def predict_batch(texts, confidence_threshold=0.5):
    """Classify several texts in one forward pass.

    Args:
        texts (list[str]): Texts to classify; invalid entries keep the
            default (0, 0.0) result at their original position.
        confidence_threshold (float): DRUG probability above which the
            label is 1.

    Returns:
        list[tuple[int, float]]: One (label, drug_probability) per input,
        aligned with ``texts``. Empty/non-list input yields []; any
        runtime error yields all-(0, 0.0).
    """
    if not model_loaded:
        load_model()

    if not texts or not isinstance(texts, list):
        return []

    # Keep only usable strings, remembering where each one came from.
    cleaned = []
    positions = []
    for position, candidate in enumerate(texts):
        if candidate and isinstance(candidate, str) and candidate.strip():
            cleaned.append(candidate.strip())
            positions.append(position)

    if not cleaned:
        return [(0, 0.0)] * len(texts)

    try:
        batch = tokenizer(
            cleaned, return_tensors="pt", truncation=True, padding=True, max_length=256
        )

        # Match the model's device before the forward pass.
        if torch.cuda.is_available() and next(model.parameters()).is_cuda:
            batch = {name: tensor.cuda() for name, tensor in batch.items()}

        with torch.no_grad():
            scores = F.softmax(model(**batch).logits, dim=-1)

        # Scatter results back to the original positions; skipped inputs
        # keep the default (0, 0.0).
        predictions = [(0, 0.0)] * len(texts)
        for row, original_pos in enumerate(positions):
            p_drug = scores[row][1].item()
            verdict = 1 if p_drug > confidence_threshold else 0
            predictions[original_pos] = (verdict, p_drug)
        return predictions
    except Exception as e:
        logger.error(f"Batch prediction error: {e}")
        return [(0, 0.0)] * len(texts)
|
| 206 |
|
| 207 |
+
# =======================
# Test predictions
# =======================
def test_predictions():
    """Run a handful of sanity-check predictions and log each outcome."""
    logger.info("Running prediction tests...")
    # The first three cases should classify as DRUG, the last three as NON_DRUG.
    test_cases = [
        "Bro, check the Insta DM. That the white or the blue? White, straight from Mumbai. Cool, payment through crypto, right? Who's bringing the stuff? Raj, Tabs, Weed and Coke. Let's not overdose this time.",
        "Got some quality hash and charas ready for pickup tonight",
        "MDMA tabs are available, payment through crypto only",
        "Hey, how's your work going today? Let's meet for coffee this evening.",
        "The weather is really nice today, perfect for a walk in the park",
        "I need to finish my project by tomorrow, can you help me?",
    ]
    for case_number, sample in enumerate(test_cases, 1):
        label, prob = predict(sample)
        expected = "DRUG" if case_number <= 3 else "NON_DRUG"
        actual = "DRUG" if label == 1 else "NON_DRUG"
        logger.info(f"Test {case_number}: Expected={expected}, Got={actual}, Prob={prob:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
# =======================
# Main interactive loop
# =======================
if __name__ == "__main__":
    # Warm up the classifier and run the built-in sanity checks first.
    load_model()
    test_predictions()

    banner = "=" * 50
    print("\n" + banner)
    print("Interactive Drug Detection")
    print("Type 'quit' to exit")
    print(banner)

    while True:
        try:
            entered = input("\nEnter text: ").strip()
            if entered.lower() in ['quit', 'exit', 'q']:
                break
            if not entered:
                print("Please enter some text.")
                continue
            verdict, drug_prob = predict(entered)
            result = "🚨 DRUG" if verdict == 1 else "✅ NON_DRUG"
            # Confidence is the probability of whichever class won.
            confidence = max(drug_prob, 1 - drug_prob)
            print(f"Result: {result}")
            print(f"Drug Probability: {drug_prob*100:.2f}%")
            print(f"Confidence: {confidence*100:.2f}%")
        except KeyboardInterrupt:
            # Ctrl-C exits cleanly instead of dumping a traceback.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")