lawlevisan committed on
Commit
51741f3
·
verified ·
1 Parent(s): 64b1d47

Update src/predict.py

Browse files
Files changed (1) hide show
  1. src/predict.py +86 -172
src/predict.py CHANGED
@@ -1,55 +1,53 @@
1
- # predict.py - Fixed version with configuration validation
2
  from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
3
  import torch
4
  import torch.nn.functional as F
5
  import logging
6
  import os
7
  import json
 
8
 
9
- # Configure logging
 
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Global variables for model and tokenizer (loaded once)
 
 
14
  model = None
15
  tokenizer = None
16
  model_loaded = False
17
 
 
 
 
18
  def validate_and_fix_config(model_path):
19
  """Validate and fix model configuration if needed"""
20
  config_path = os.path.join(model_path, "config.json")
21
-
22
  if not os.path.exists(config_path):
23
  logger.warning(f"Config file not found at {config_path}")
24
  return False
25
-
26
  try:
27
  with open(config_path, 'r') as f:
28
  config_data = json.load(f)
29
 
30
- # Check for the problematic configuration
31
  dim = config_data.get('dim', 768)
32
  n_heads = config_data.get('n_heads', 12)
33
-
34
  if dim % n_heads != 0:
35
  logger.warning(f"Configuration issue detected: dim={dim} not divisible by n_heads={n_heads}")
36
-
37
- # Backup original config
38
  backup_path = config_path + ".backup"
39
  if not os.path.exists(backup_path):
40
- import shutil
41
  shutil.copy2(config_path, backup_path)
42
  logger.info(f"Backed up original config to {backup_path}")
43
 
44
- # Fix configuration with standard DistilBERT values
45
  config_data['dim'] = 768
46
  config_data['n_heads'] = 12
47
  config_data['hidden_dim'] = 3072
48
-
49
- # Write fixed config
50
  with open(config_path, 'w') as f:
51
  json.dump(config_data, f, indent=2)
52
-
53
  logger.info("Fixed configuration with standard DistilBERT dimensions")
54
  return True
55
 
@@ -60,37 +58,30 @@ def validate_and_fix_config(model_path):
60
  logger.error(f"Error validating/fixing config: {e}")
61
  return False
62
 
 
 
 
63
  def load_model_with_fallback(model_name):
64
- """Load model with fallback strategies"""
65
  global model, tokenizer
66
-
67
- # Strategy 1: Try loading with config validation
68
  if os.path.exists(model_name):
69
  logger.info(f"Attempting to load local model from {model_name}")
70
-
71
- # Validate and fix config first
72
- config_valid = validate_and_fix_config(model_name)
73
- if not config_valid:
74
- logger.warning("Could not validate/fix config, trying anyway...")
75
-
76
  try:
77
  tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
78
  model = DistilBertForSequenceClassification.from_pretrained(
79
  model_name,
80
- ignore_mismatched_sizes=True # This helps with dimension issues
81
  )
82
  logger.info("Successfully loaded local model")
83
  return True
84
-
85
  except Exception as e:
86
  logger.error(f"Failed to load local model: {e}")
87
-
88
- # Strategy 2: Create a compatible model with existing weights
89
  if os.path.exists(model_name):
90
  try:
91
  logger.info("Attempting to load with custom configuration...")
92
-
93
- # Create a working config
94
  config = DistilBertConfig(
95
  vocab_size=30522,
96
  max_position_embeddings=512,
@@ -106,14 +97,8 @@ def load_model_with_fallback(model_name):
106
  seq_classif_dropout=0.2,
107
  num_labels=2
108
  )
109
-
110
- # Load tokenizer
111
  tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
112
-
113
- # Create model with fixed config
114
  model = DistilBertForSequenceClassification(config)
115
-
116
- # Try to load existing weights
117
  weights_path = os.path.join(model_name, "pytorch_model.bin")
118
  if os.path.exists(weights_path):
119
  try:
@@ -122,221 +107,150 @@ def load_model_with_fallback(model_name):
122
  logger.info("Loaded existing weights with custom config")
123
  except Exception as weight_error:
124
  logger.warning(f"Could not load weights: {weight_error}")
125
- logger.info("Using randomly initialized model")
126
-
127
  return True
128
-
129
  except Exception as e:
130
  logger.error(f"Custom config loading failed: {e}")
131
 
132
- # Strategy 3: Use pre-trained DistilBERT as fallback
133
  try:
134
  logger.info("Loading fallback model from HuggingFace...")
135
  tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
136
  model = DistilBertForSequenceClassification.from_pretrained(
137
- 'distilbert-base-uncased',
138
- num_labels=2
139
  )
140
- logger.warning("Using pre-trained DistilBERT as fallback - will need retraining for drug classification")
141
  return True
142
-
143
  except Exception as e:
144
  logger.error(f"Fallback model loading failed: {e}")
145
  return False
146
 
 
 
 
147
  def load_model(model_name="drug_classifier_model"):
148
- """Load model and tokenizer once with enhanced error handling"""
149
  global model, tokenizer, model_loaded
150
-
151
  if model_loaded:
152
- return # Already loaded
153
-
154
  try:
155
- # Use the enhanced loading function
156
  success = load_model_with_fallback(model_name)
157
-
158
  if not success:
159
  raise RuntimeError("All model loading strategies failed")
160
-
161
- model.eval() # Set to evaluation mode
162
-
163
- # Move to GPU if available
164
  if torch.cuda.is_available():
165
- model = model.cuda()
166
  logger.info("Model moved to GPU")
167
-
168
  model_loaded = True
169
- logger.info(f"Successfully loaded model and tokenizer")
170
-
171
- # Log model configuration if available
172
- if hasattr(model.config, 'id2label'):
173
- logger.info(f"Model labels: {model.config.id2label}")
174
-
175
  except Exception as e:
176
  logger.error(f"Failed to load model or tokenizer: {e}")
177
  raise
178
 
 
 
 
179
  def predict(text, confidence_threshold=0.5):
180
- """
181
- Predict whether the input text is DRUG (1) or NON_DRUG (0).
182
-
183
- Args:
184
- text (str): Input text to classify
185
- confidence_threshold (float): Threshold for DRUG classification (default: 0.5)
186
-
187
- Returns:
188
- tuple: (label, drug_probability)
189
- - label: 1 for DRUG, 0 for NON_DRUG
190
- - drug_probability: float between 0 and 1
191
- """
192
- # Ensure model is loaded
193
  if not model_loaded:
194
  load_model()
195
-
196
- # Input validation
197
  if not text or not isinstance(text, str):
198
- logger.warning("Empty or invalid input text provided to predict")
199
  return 0, 0.0
200
-
201
  text = text.strip()
202
- if len(text) == 0:
203
- logger.warning("Empty text after stripping provided to predict")
204
  return 0, 0.0
205
-
206
  try:
207
- # Tokenize input - use same max_length as training
208
  inputs = tokenizer(
209
- text,
210
- return_tensors="pt",
211
- truncation=True,
212
- padding=True,
213
- max_length=256 # Match training script
214
  )
215
-
216
- # Move inputs to same device as model
217
  if torch.cuda.is_available() and next(model.parameters()).is_cuda:
218
  inputs = {k: v.cuda() for k, v in inputs.items()}
219
-
220
- # Make prediction
221
  with torch.no_grad():
222
  outputs = model(**inputs)
223
  probs = F.softmax(outputs.logits, dim=-1)
224
-
225
- non_drug_prob = probs[0][0].item() # Probability for NON_DRUG (index 0)
226
- drug_prob = probs[0][1].item() # Probability for DRUG (index 1)
227
-
228
- # Apply threshold for classification
229
  pred_label = 1 if drug_prob > confidence_threshold else 0
230
-
231
- # Log prediction details
232
- logger.info(f"Prediction for: '{text[:100]}{'...' if len(text) > 100 else ''}'")
233
- logger.info(f" Result: {'DRUG' if pred_label == 1 else 'NON_DRUG'}")
234
- logger.info(f" DRUG probability: {drug_prob:.4f} ({drug_prob*100:.2f}%)")
235
- logger.info(f" NON_DRUG probability: {non_drug_prob:.4f} ({non_drug_prob*100:.2f}%)")
236
- logger.info(f" Confidence: {max(drug_prob, non_drug_prob):.4f}")
237
-
238
  return pred_label, drug_prob
239
-
240
  except Exception as e:
241
- logger.error(f"Error during prediction: {e}")
242
- logger.error(f"Input text: {text}")
243
  return 0, 0.0
244
 
 
 
 
245
  def predict_batch(texts, confidence_threshold=0.5):
246
- """
247
- Predict for multiple texts at once (more efficient)
248
-
249
- Args:
250
- texts (list): List of texts to classify
251
- confidence_threshold (float): Threshold for DRUG classification
252
-
253
- Returns:
254
- list: List of (label, drug_probability) tuples
255
- """
256
  if not model_loaded:
257
  load_model()
258
-
259
  if not texts or not isinstance(texts, list):
260
- logger.warning("Empty or invalid text list provided to predict_batch")
261
  return []
262
-
263
- # Filter out empty texts and keep track of original indices
264
- valid_texts = []
265
- valid_indices = []
266
- for i, text in enumerate(texts):
267
- if text and isinstance(text, str) and text.strip():
268
- valid_texts.append(text.strip())
269
  valid_indices.append(i)
270
-
271
  if not valid_texts:
272
- logger.warning("No valid texts found in batch")
273
- return [(0, 0.0)] * len(texts)
274
-
275
  try:
276
- # Tokenize all texts
277
- inputs = tokenizer(
278
- valid_texts,
279
- return_tensors="pt",
280
- truncation=True,
281
- padding=True,
282
- max_length=256
283
- )
284
-
285
- # Move to device
286
  if torch.cuda.is_available() and next(model.parameters()).is_cuda:
287
- inputs = {k: v.cuda() for k, v in inputs.items()}
288
-
289
- # Make predictions
290
  with torch.no_grad():
291
  outputs = model(**inputs)
292
  probs = F.softmax(outputs.logits, dim=-1)
293
-
294
- # Create results array with proper indexing
295
- results = [(0, 0.0)] * len(texts)
296
- for i, valid_idx in enumerate(valid_indices):
297
  drug_prob = probs[i][1].item()
298
  pred_label = 1 if drug_prob > confidence_threshold else 0
299
- results[valid_idx] = (pred_label, drug_prob)
300
-
301
- logger.info(f"Batch prediction completed for {len(valid_texts)} texts")
302
  return results
303
-
304
  except Exception as e:
305
- logger.error(f"Error during batch prediction: {e}")
306
- return [(0, 0.0)] * len(texts)
307
 
 
 
 
308
  def test_predictions():
309
- """Test function to verify model predictions"""
310
  logger.info("Running prediction tests...")
311
-
312
  test_cases = [
313
- # Should be DRUG
314
  "Bro, check the Insta DM. That the white or the blue? White, straight from Mumbai. Cool, payment through crypto, right? Who's bringing the stuff? Raj, Tabs, Weed and Coke. Let's not overdose this time.",
315
  "Got some quality hash and charas ready for pickup tonight",
316
  "MDMA tabs are available, payment through crypto only",
317
-
318
- # Should be NON_DRUG
319
  "Hey, how's your work going today? Let's meet for coffee this evening.",
320
  "The weather is really nice today, perfect for a walk in the park",
321
  "I need to finish my project by tomorrow, can you help me?",
322
  ]
323
-
324
- for i, text in enumerate(test_cases, 1):
325
- logger.info(f"\n--- Test Case {i} ---")
326
  label, prob = predict(text)
327
- expected = "DRUG" if i <= 3 else "NON_DRUG"
328
- actual = "DRUG" if label == 1 else "NON_DRUG"
329
- logger.info(f"Expected: {expected}, Got: {actual}, Probability: {prob:.4f}")
330
-
331
- # ===========================
332
- # HF Spaces / Production Ready
333
- # ===========================
334
 
 
 
 
335
  if __name__ == "__main__":
336
- # Load model once
337
  load_model()
338
-
339
- # Optional: run test predictions to verify setup
340
  test_predictions()
341
 
342
- # No interactive input() calls here – compatible with Hugging Face Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # predict.py - Production + Interactive Compatible Version
2
  from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
3
  import torch
4
  import torch.nn.functional as F
5
  import logging
6
  import os
7
  import json
8
+ import shutil
# =======================
# Logging configuration
# =======================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =======================
# Global variables
# =======================
# Lazily populated by load_model(); shared across the module.
model, tokenizer = None, None
model_loaded = False
# =======================
# Config validation/fix
# =======================
def validate_and_fix_config(model_path):
    """Validate the DistilBERT config in *model_path* and repair it if broken.

    DistilBERT requires ``dim`` to be divisible by ``n_heads``; a checkpoint
    saved with an inconsistent pair cannot be loaded.  When that is detected,
    the original ``config.json`` is backed up once (``config.json.backup``)
    and rewritten with the standard base dimensions (dim=768, n_heads=12,
    hidden_dim=3072).

    Args:
        model_path: Directory expected to contain ``config.json``.

    Returns:
        bool: True when the config exists and is valid (or was fixed);
        False when it is missing or could not be read/repaired.
    """
    log = logging.getLogger(__name__)
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        log.warning("Config file not found at %s", config_path)
        return False
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)

        dim = config_data.get("dim", 768)
        n_heads = config_data.get("n_heads", 12)
        if dim % n_heads != 0:
            log.warning(
                "Configuration issue detected: dim=%s not divisible by n_heads=%s",
                dim, n_heads,
            )
            # Keep a one-time backup so the original settings stay recoverable.
            backup_path = config_path + ".backup"
            if not os.path.exists(backup_path):
                shutil.copy2(config_path, backup_path)
                log.info("Backed up original config to %s", backup_path)

            # Fix configuration: overwrite with standard DistilBERT-base dims.
            config_data["dim"] = 768
            config_data["n_heads"] = 12
            config_data["hidden_dim"] = 3072
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config_data, f, indent=2)
            log.info("Fixed configuration with standard DistilBERT dimensions")
        # Config is usable (either originally valid or just repaired).
        return True
    except Exception as e:
        log.error("Error validating/fixing config: %s", e)
        return False
# =======================
# Model loading with fallback
# =======================
def load_model_with_fallback(model_name):
    """Populate the module-level ``model``/``tokenizer`` via three strategies.

    1. Load the local checkpoint directory as-is (after config validation).
    2. Rebuild a standard DistilBERT config and load the raw weights into it.
    3. Fall back to pre-trained ``distilbert-base-uncased`` from the Hub.

    Args:
        model_name: Path to a local model directory (or a Hub model id for
            the ``from_pretrained`` call in strategy 1).

    Returns:
        bool: True when any strategy produced a usable model, else False.
    """
    global model, tokenizer

    # Strategy 1: Load local model
    if os.path.exists(model_name):
        logger.info(f"Attempting to load local model from {model_name}")
        # Best-effort config repair; loading is attempted even if this fails.
        validate_and_fix_config(model_name)
        try:
            tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
            model = DistilBertForSequenceClassification.from_pretrained(
                model_name,
                ignore_mismatched_sizes=True,
            )
            logger.info("Successfully loaded local model")
            return True
        except Exception as e:
            logger.error(f"Failed to load local model: {e}")

    # Strategy 2: Custom config + weights
    if os.path.exists(model_name):
        try:
            logger.info("Attempting to load with custom configuration...")
            # Standard DistilBERT-base hyperparameters (dim divisible by
            # n_heads).  NOTE(review): the middle parameters were cut from the
            # diff view — confirm against the training script.
            config = DistilBertConfig(
                vocab_size=30522,
                max_position_embeddings=512,
                n_layers=6,
                n_heads=12,
                dim=768,
                hidden_dim=3072,
                dropout=0.1,
                attention_dropout=0.1,
                seq_classif_dropout=0.2,
                num_labels=2,
            )
            tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
            model = DistilBertForSequenceClassification(config)

            weights_path = os.path.join(model_name, "pytorch_model.bin")
            if os.path.exists(weights_path):
                try:
                    # NOTE(review): torch.load unpickles arbitrary objects —
                    # only load checkpoints from a trusted source.
                    state_dict = torch.load(weights_path, map_location="cpu")
                    # strict=False tolerates head-size mismatches.
                    model.load_state_dict(state_dict, strict=False)
                    logger.info("Loaded existing weights with custom config")
                except Exception as weight_error:
                    logger.warning(f"Could not load weights: {weight_error}")
            return True
        except Exception as e:
            logger.error(f"Custom config loading failed: {e}")

    # Strategy 3: HuggingFace fallback
    try:
        logger.info("Loading fallback model from HuggingFace...")
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased', num_labels=2
        )
        logger.warning("Using pre-trained DistilBERT as fallback - retraining recommended")
        return True
    except Exception as e:
        logger.error(f"Fallback model loading failed: {e}")
        return False
# =======================
# Load model globally
# =======================
def load_model(model_name="drug_classifier_model"):
    """Load the classifier once into module globals; safe to call repeatedly.

    Args:
        model_name: Local checkpoint directory (or Hub id) forwarded to
            ``load_model_with_fallback``.

    Raises:
        RuntimeError: If every loading strategy fails.
    """
    global model, tokenizer, model_loaded

    if model_loaded:
        return  # already initialized — make this call a no-op
    try:
        if not load_model_with_fallback(model_name):
            raise RuntimeError("All model loading strategies failed")

        model.eval()  # inference mode: disables dropout

        if torch.cuda.is_available():
            model.cuda()  # nn.Module.cuda() moves parameters in place
            logger.info("Model moved to GPU")

        model_loaded = True
        logger.info("Successfully loaded model and tokenizer")
    except Exception as e:
        # logger.exception records the traceback; re-raise so callers fail fast.
        logger.exception(f"Failed to load model or tokenizer: {e}")
        raise
# =======================
# Single prediction
# =======================
def predict(text, confidence_threshold=0.5):
    """Classify *text* as DRUG (1) or NON_DRUG (0).

    Args:
        text: Input text to classify.
        confidence_threshold: DRUG probability above which label 1 is returned.

    Returns:
        tuple: ``(label, drug_probability)`` — ``(0, 0.0)`` on invalid input
        or any runtime failure (best-effort: this function never raises).
    """
    if not model_loaded:
        load_model()

    # Reject non-strings and empty/whitespace-only input up front.
    if not text or not isinstance(text, str):
        return 0, 0.0
    text = text.strip()
    if not text:
        return 0, 0.0

    try:
        # max_length=256 matches the training-time tokenization.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        # Keep inputs on the same device as the model parameters.
        if torch.cuda.is_available() and next(model.parameters()).is_cuda:
            inputs = {k: v.cuda() for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)
            drug_prob = probs[0][1].item()  # index 1 == DRUG class

        pred_label = 1 if drug_prob > confidence_threshold else 0
        return pred_label, drug_prob
    except Exception as e:
        # logger.exception keeps the traceback for debugging model failures.
        logger.exception(f"Prediction error: {e}")
        return 0, 0.0
# =======================
# Batch prediction
# =======================
def predict_batch(texts, confidence_threshold=0.5):
    """Classify many texts in a single forward pass (more efficient).

    Args:
        texts: List of input strings.
        confidence_threshold: DRUG probability above which label 1 is returned.

    Returns:
        list: One ``(label, drug_probability)`` tuple per input, aligned with
        *texts*; invalid/empty entries yield ``(0, 0.0)``.  Returns ``[]``
        when *texts* itself is empty or not a list.
    """
    if not model_loaded:
        load_model()

    if not texts or not isinstance(texts, list):
        return []

    # Keep only non-empty strings, remembering their original positions.
    valid_texts, valid_indices = [], []
    for i, t in enumerate(texts):
        if t and isinstance(t, str) and t.strip():
            valid_texts.append(t.strip())
            valid_indices.append(i)

    if not valid_texts:
        return [(0, 0.0)] * len(texts)

    try:
        inputs = tokenizer(
            valid_texts, return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        # Keep inputs on the same device as the model parameters.
        if torch.cuda.is_available() and next(model.parameters()).is_cuda:
            inputs = {k: v.cuda() for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)

        # Scatter predictions back to the original list positions.
        results = [(0, 0.0)] * len(texts)
        for i, idx in enumerate(valid_indices):
            drug_prob = probs[i][1].item()  # index 1 == DRUG class
            pred_label = 1 if drug_prob > confidence_threshold else 0
            results[idx] = (pred_label, drug_prob)
        return results
    except Exception as e:
        # logger.exception keeps the traceback for debugging model failures.
        logger.exception(f"Batch prediction error: {e}")
        return [(0, 0.0)] * len(texts)
# =======================
# Test predictions
# =======================
def test_predictions():
    """Run a fixed set of smoke-test sentences through predict() and log the results."""
    logger.info("Running prediction tests...")
    # First three cases are expected DRUG; the remaining three NON_DRUG.
    test_cases = [
        "Bro, check the Insta DM. That the white or the blue? White, straight from Mumbai. Cool, payment through crypto, right? Who's bringing the stuff? Raj, Tabs, Weed and Coke. Let's not overdose this time.",
        "Got some quality hash and charas ready for pickup tonight",
        "MDMA tabs are available, payment through crypto only",
        "Hey, how's your work going today? Let's meet for coffee this evening.",
        "The weather is really nice today, perfect for a walk in the park",
        "I need to finish my project by tomorrow, can you help me?",
    ]
    for i, sample in enumerate(test_cases, 1):
        label, prob = predict(sample)
        expected = "DRUG" if i <= 3 else "NON_DRUG"
        actual = "DRUG" if label == 1 else "NON_DRUG"
        logger.info(f"Test {i}: Expected={expected}, Got={actual}, Prob={prob:.4f}")
# =======================
# Main interactive loop
# =======================
if __name__ == "__main__":
    load_model()
    test_predictions()

    # Simple REPL for manual checks; Ctrl-C or 'quit' exits.
    banner = "=" * 50
    print("\n" + banner)
    print("Interactive Drug Detection")
    print("Type 'quit' to exit")
    print(banner)

    while True:
        try:
            user_input = input("\nEnter text: ").strip()
            if user_input.lower() in ['quit', 'exit', 'q']:
                break
            if not user_input:
                print("Please enter some text.")
                continue
            label, prob = predict(user_input)
            result = "🚨 DRUG" if label == 1 else "✅ NON_DRUG"
            confidence = max(prob, 1 - prob)
            print(f"Result: {result}")
            print(f"Drug Probability: {prob*100:.2f}%")
            print(f"Confidence: {confidence*100:.2f}%")
        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")