lawlevisan committed on
Commit
c452442
·
verified ·
1 Parent(s): dc729a0

Update src/predict.py

Browse files
Files changed (1) hide show
  1. src/predict.py +364 -49
src/predict.py CHANGED
@@ -1,50 +1,365 @@
 
 
 
 
 
1
  import os
2
- from pathlib import Path
3
- from typing import List
4
-
5
# === Base Directories & Defaults ===
# BASE_DIR resolves to the directory containing this file, so the bundled
# model directory is found regardless of the process working directory.
BASE_DIR = Path(__file__).resolve().parent
DEFAULT_MODEL_PATH = BASE_DIR / "drug_classifier_model"
DEFAULT_WHISPER_MODEL = "base"  # Whisper model size used for transcription
DEFAULT_MAX_LENGTH = 256        # tokenizer truncation length (tokens)
DEFAULT_THRESHOLD = 0.70        # default classification decision threshold

# === Keywords ===
# Slang terms used for lightweight keyword spotting alongside the classifier.
DRUG_KEYWORDS = [
    "stuff", "package", "goods", "deal", "pick up", "pickup", "stash", "green",
    "weed", "pot", "coke", "cocaine", "white", "powder", "score", "high",
    "gram", "g", "pill", "tabs", "md", "mdma", "lsd", "charas", "hash", "ganja",
    "dope", "joint", "puff", "trip", "syringe", "needle", "gear", "supply",
    "quality", "batch", "hook me up", "hookup", "overdose", "rave", "party"
]

# Subset of DRUG_KEYWORDS treated as high-risk signals.
HIGH_RISK_KEYWORDS = [
    "coke", "cocaine", "weed", "pot", "tabs", "mdma", "lsd", "charas", "hash",
    "ganja", "dope", "overdose", "syringe", "needle", "gear"
]
25
-
26
# === Production Config ===
class ProductionConfig:
    """Runtime configuration resolved from environment variables.

    Every setting falls back to a module-level default so the service can
    start with no environment configured at all.
    """

    def __init__(self):
        # Model settings
        self.MODEL_PATH = Path(os.getenv("MODEL_PATH", str(DEFAULT_MODEL_PATH)))
        self.WHISPER_MODEL = os.getenv("WHISPER_MODEL", DEFAULT_WHISPER_MODEL)
        self.MAX_LENGTH = int(os.getenv("MAX_LENGTH", DEFAULT_MAX_LENGTH))
        self.THRESHOLD = float(os.getenv("THRESHOLD", DEFAULT_THRESHOLD))

        # Limits
        self.MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "50"))
        self.MAX_AUDIO_DURATION = int(os.getenv("MAX_AUDIO_DURATION", "300"))  # 5 mins
        # Extensions are stored without the leading dot.
        self.ALLOWED_EXTENSIONS: List[str] = ["wav", "mp3", "m4a", "flac", "ogg"]

        # Security
        self.RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "10"))
        self.RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "3600"))  # 1h
        self.ENABLE_LOGGING = os.getenv("ENABLE_LOGGING", "true").lower() == "true"

    def is_allowed_file(self, filename: str) -> bool:
        """Return True if *filename* has one of the allowed audio extensions.

        BUG FIX: the original compared with endswith(ext) and no dot, so any
        filename that merely ends in the letters of an extension (e.g.
        "notmp3") was wrongly accepted. The check now requires a proper
        dot-delimited extension.
        """
        lowered = filename.lower()
        return any(lowered.endswith("." + ext) for ext in self.ALLOWED_EXTENSIONS)
48
-
49
# === Global config instance ===
# Shared singleton; other modules import this rather than re-reading env vars.
config = ProductionConfig()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # predict.py - Fixed version with configuration validation
2
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import logging
6
  import os
7
+ import json
8
+
9
# Configure logging
# NOTE: basicConfig runs at import time; it is a no-op if the embedding
# application has already configured the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model and tokenizer (loaded once)
# Populated lazily by load_model(); model_loaded guards against re-loading.
model = None          # classifier model instance once loaded
tokenizer = None      # matching tokenizer instance once loaded
model_loaded = False  # set True after a successful load_model()
17
+
18
def validate_and_fix_config(model_path):
    """Validate the DistilBERT ``config.json`` under *model_path*, repairing it if broken.

    A checkpoint whose ``dim`` is not divisible by ``n_heads`` cannot be
    loaded by transformers; in that case the original file is backed up to
    ``config.json.backup`` (once) and the dimensions are rewritten to the
    standard DistilBERT values (dim=768, n_heads=12, hidden_dim=3072).

    Args:
        model_path: directory expected to contain ``config.json``.

    Returns:
        True if the config is valid or was fixed; False if the file is
        missing or could not be read/parsed/written.
    """
    import shutil  # hoisted out of the fix branch; local import keeps module import cheap

    # Same logger object as the module-level `logger` (same __name__).
    log = logging.getLogger(__name__)
    config_path = os.path.join(model_path, "config.json")

    if not os.path.exists(config_path):
        log.warning(f"Config file not found at {config_path}")
        return False

    try:
        with open(config_path, 'r') as f:
            config_data = json.load(f)

        # Missing keys fall back to the standard (valid) DistilBERT values.
        dim = config_data.get('dim', 768)
        n_heads = config_data.get('n_heads', 12)

        if dim % n_heads != 0:
            log.warning(f"Configuration issue detected: dim={dim} not divisible by n_heads={n_heads}")

            # Back up the original config exactly once.
            backup_path = config_path + ".backup"
            if not os.path.exists(backup_path):
                shutil.copy2(config_path, backup_path)
                log.info(f"Backed up original config to {backup_path}")

            # Overwrite with standard DistilBERT dimensions.
            config_data['dim'] = 768
            config_data['n_heads'] = 12
            config_data['hidden_dim'] = 3072
            with open(config_path, 'w') as f:
                json.dump(config_data, f, indent=2)

            log.info("Fixed configuration with standard DistilBERT dimensions")
            return True

        log.info("Configuration is valid")
        return True

    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError; narrowed from bare Exception.
        log.error(f"Error validating/fixing config: {e}")
        return False
63
def load_model_with_fallback(model_name):
    """Load the classifier model and tokenizer, trying progressively weaker fallbacks.

    Strategy 1: load the local fine-tuned checkpoint (after config repair).
    Strategy 2: rebuild a standard DistilBERT config and pull in whatever
                local weights fit (non-matching tensors are skipped).
    Strategy 3: download plain distilbert-base-uncased (needs retraining).

    Side effects: sets the module globals ``model`` and ``tokenizer``.

    Returns:
        True on success, False if every strategy failed.
    """
    global model, tokenizer

    # Strategy 1: Try loading with config validation
    if os.path.exists(model_name):
        logger.info(f"Attempting to load local model from {model_name}")

        # Repair a dim/n_heads mismatch before transformers sees it.
        if not validate_and_fix_config(model_name):
            logger.warning("Could not validate/fix config, trying anyway...")

        try:
            tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
            model = DistilBertForSequenceClassification.from_pretrained(
                model_name,
                ignore_mismatched_sizes=True  # tolerate classifier-head size mismatches
            )
            logger.info("Successfully loaded local model")
            return True

        except Exception as e:
            logger.error(f"Failed to load local model: {e}")

    # Strategy 2: Create a compatible model and load whatever weights match
    if os.path.exists(model_name):
        try:
            logger.info("Attempting to load with custom configuration...")

            # Standard DistilBERT dimensions; num_labels=2 for DRUG/NON_DRUG.
            config = DistilBertConfig(
                vocab_size=30522,
                max_position_embeddings=512,
                dim=768,
                n_layers=6,
                n_heads=12,
                hidden_dim=3072,
                dropout=0.1,
                attention_dropout=0.1,
                activation='gelu',
                initializer_range=0.02,
                qa_dropout=0.1,
                seq_classif_dropout=0.2,
                num_labels=2
            )

            tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
            model = DistilBertForSequenceClassification(config)

            # NOTE(review): only the legacy pickle checkpoint name is probed
            # here; checkpoints saved as model.safetensors are expected to be
            # handled by Strategy 1 above — confirm for your checkpoints.
            weights_path = os.path.join(model_name, "pytorch_model.bin")
            if os.path.exists(weights_path):
                try:
                    # SECURITY FIX: weights_only=True stops torch.load from
                    # unpickling arbitrary objects in an untrusted checkpoint.
                    state_dict = torch.load(
                        weights_path, map_location='cpu', weights_only=True
                    )
                    # strict=False: skip tensors whose shapes/names don't match.
                    model.load_state_dict(state_dict, strict=False)
                    logger.info("Loaded existing weights with custom config")
                except Exception as weight_error:
                    logger.warning(f"Could not load weights: {weight_error}")
                    logger.info("Using randomly initialized model")

            return True

        except Exception as e:
            logger.error(f"Custom config loading failed: {e}")

    # Strategy 3: Use pre-trained DistilBERT as fallback
    try:
        logger.info("Loading fallback model from HuggingFace...")
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=2
        )
        logger.warning("Using pre-trained DistilBERT as fallback - will need retraining for drug classification")
        return True

    except Exception as e:
        logger.error(f"Fallback model loading failed: {e}")
        return False
147
def load_model(model_name="drug_classifier_model"):
    """Initialise the global model/tokenizer pair exactly once.

    Delegates the actual loading to load_model_with_fallback(), then puts the
    model into eval mode and onto the GPU when one is available. Safe to call
    repeatedly: subsequent calls are no-ops.

    Raises:
        RuntimeError: when every loading strategy fails.
    """
    global model, tokenizer, model_loaded

    if model_loaded:
        return  # nothing to do — already initialised

    try:
        if not load_model_with_fallback(model_name):
            raise RuntimeError("All model loading strategies failed")

        # Inference-only: disable dropout etc.
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
            logger.info("Model moved to GPU")

        model_loaded = True
        logger.info(f"Successfully loaded model and tokenizer")

        # Surface the label mapping when the checkpoint provides one.
        if hasattr(model.config, 'id2label'):
            logger.info(f"Model labels: {model.config.id2label}")

    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        raise
179
def predict(text, confidence_threshold=0.5):
    """Classify *text* as DRUG (1) or NON_DRUG (0).

    Args:
        text (str): message to score.
        confidence_threshold (float): minimum DRUG probability required to
            emit label 1 (default: 0.5).

    Returns:
        tuple: (label, drug_probability) where label is 1 for DRUG / 0 for
        NON_DRUG and drug_probability lies in [0, 1]. Invalid input or an
        internal error yields (0, 0.0).
    """
    # Lazily initialise the globals on first use.
    if not model_loaded:
        load_model()

    # Reject non-string or falsy input up front.
    if not text or not isinstance(text, str):
        logger.warning("Empty or invalid input text provided to predict")
        return 0, 0.0

    text = text.strip()
    if not text:
        logger.warning("Empty text after stripping provided to predict")
        return 0, 0.0

    try:
        # Same max_length as used at training time.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256
        )

        # Keep input tensors on the same device as the model parameters.
        if torch.cuda.is_available() and next(model.parameters()).is_cuda:
            encoded = {name: tensor.cuda() for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits
            probs = F.softmax(logits, dim=-1)

        non_drug_prob = probs[0][0].item()  # index 0 = NON_DRUG
        drug_prob = probs[0][1].item()      # index 1 = DRUG
        pred_label = 1 if drug_prob > confidence_threshold else 0

        logger.info(f"Prediction for: '{text[:100]}{'...' if len(text) > 100 else ''}'")
        logger.info(f" Result: {'DRUG' if pred_label == 1 else 'NON_DRUG'}")
        logger.info(f" DRUG probability: {drug_prob:.4f} ({drug_prob*100:.2f}%)")
        logger.info(f" NON_DRUG probability: {non_drug_prob:.4f} ({non_drug_prob*100:.2f}%)")
        logger.info(f" Confidence: {max(drug_prob, non_drug_prob):.4f}")

        return pred_label, drug_prob

    except Exception as e:
        logger.error(f"Error during prediction: {e}")
        logger.error(f"Input text: {text}")
        return 0, 0.0
245
def predict_batch(texts, confidence_threshold=0.5):
    """Classify many texts in a single forward pass (more efficient than predict()).

    Args:
        texts (list): texts to classify.
        confidence_threshold (float): minimum DRUG probability for label 1.

    Returns:
        list: one (label, drug_probability) tuple per input, in order;
        invalid entries (non-strings, empty strings) yield (0, 0.0).
    """
    if not model_loaded:
        load_model()

    if not texts or not isinstance(texts, list):
        logger.warning("Empty or invalid text list provided to predict_batch")
        return []

    # Keep only usable strings, remembering their original positions.
    kept = [
        (pos, entry.strip())
        for pos, entry in enumerate(texts)
        if entry and isinstance(entry, str) and entry.strip()
    ]

    if not kept:
        logger.warning("No valid texts found in batch")
        return [(0, 0.0)] * len(texts)

    positions = [pos for pos, _ in kept]
    valid_texts = [entry for _, entry in kept]

    try:
        batch = tokenizer(
            valid_texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256
        )

        # Match the model's device.
        if torch.cuda.is_available() and next(model.parameters()).is_cuda:
            batch = {name: tensor.cuda() for name, tensor in batch.items()}

        with torch.no_grad():
            probs = F.softmax(model(**batch).logits, dim=-1)

        # Scatter predictions back to their original indices; the rest stay (0, 0.0).
        results = [(0, 0.0)] * len(texts)
        for row, pos in enumerate(positions):
            drug_prob = probs[row][1].item()
            results[pos] = (1 if drug_prob > confidence_threshold else 0, drug_prob)

        logger.info(f"Batch prediction completed for {len(valid_texts)} texts")
        return results

    except Exception as e:
        logger.error(f"Error during batch prediction: {e}")
        return [(0, 0.0)] * len(texts)
308
def test_predictions():
    """Smoke-test the classifier against known DRUG and NON_DRUG samples."""
    logger.info("Running prediction tests...")

    drug_samples = [
        "Bro, check the Insta DM. That the white or the blue? White, straight from Mumbai. Cool, payment through crypto, right? Who's bringing the stuff? Raj, Tabs, Weed and Coke. Let's not overdose this time.",
        "Got some quality hash and charas ready for pickup tonight",
        "MDMA tabs are available, payment through crypto only",
    ]
    benign_samples = [
        "Hey, how's your work going today? Let's meet for coffee this evening.",
        "The weather is really nice today, perfect for a walk in the park",
        "I need to finish my project by tomorrow, can you help me?",
    ]

    for i, text in enumerate(drug_samples + benign_samples, 1):
        logger.info(f"\n--- Test Case {i} ---")
        label, prob = predict(text)
        # The first len(drug_samples) cases are expected to be DRUG.
        expected = "DRUG" if i <= len(drug_samples) else "NON_DRUG"
        actual = "DRUG" if label == 1 else "NON_DRUG"
        logger.info(f"Expected: {expected}, Got: {actual}, Probability: {prob:.4f}")
331
if __name__ == "__main__":
    # One-time model initialisation, then the built-in smoke tests.
    load_model()
    test_predictions()

    # Optional REPL, gated behind an environment flag.
    if os.getenv("ENABLE_INTERACTIVE", "false").lower() == "true":
        banner = "=" * 50
        print("\n" + banner)
        print("Interactive Drug Detection Testing")
        print("Enter text to classify (or 'quit' to exit)")
        print(banner)

        while True:
            try:
                text = input("\nEnter text: ").strip()
                if text.lower() in ['quit', 'exit', 'q']:
                    break
                if not text:
                    print("Please enter some text.")
                    continue

                label, prob = predict(text)
                verdict = "🚨 DRUG" if label == 1 else "✅ NON_DRUG"
                print(f"Result: {verdict}")
                print(f"Drug Probability: {prob*100:.2f}%")
                # Confidence is the probability of whichever class won.
                print(f"Confidence: {max(prob, 1-prob)*100:.2f}%")

            except KeyboardInterrupt:
                print("\nExiting...")
                break
            except Exception as e:
                print(f"Error: {e}")