"""Fine-tune DistilBERT as a binary DRUG / NON_DRUG text classifier.

Pipeline: load a labelled CSV (``train.csv`` with ``text`` and ``label``
columns), normalize labels, optionally augment with a few synthetic DRUG
examples when the class balance is poor, train DistilBERT with a
class-weighted cross-entropy loss, evaluate, and save model + tokenizer.
"""

import logging
import os

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# We'll now primarily use your existing 1500-entry dataset
# Synthetic examples are only used for targeted augmentation if needed


def load_and_augment_data(csv_path="train.csv"):
    """Load the labelled dataset, normalize labels, and optionally augment.

    Args:
        csv_path: Path to a CSV with ``text`` and ``label`` columns.

    Returns:
        A ``pandas.DataFrame`` with cleaned ``text``/``label`` columns where
        every label is exactly ``"DRUG"`` or ``"NON_DRUG"``.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
        ValueError: If the required columns are missing.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Training data file not found: {csv_path}")

    # Load your existing dataset
    df = pd.read_csv(csv_path)
    logger.info(f"Loaded {len(df)} examples from {csv_path}")

    # Validate columns
    if not {'text', 'label'}.issubset(df.columns):
        raise ValueError("train.csv must contain 'text' and 'label' columns")

    # Check for invalid or missing data
    initial_len = len(df)
    df = df.dropna(subset=['text', 'label'])
    if len(df) < initial_len:
        logger.warning(f"Dropped {initial_len - len(df)} rows with missing data")

    # Clean and validate labels - handle various possible formats
    df['label'] = df['label'].astype(str).str.upper().str.strip()

    # Map common label variations to the two canonical labels
    label_mapping = {
        'DRUG': 'DRUG',
        'NON_DRUG': 'NON_DRUG',
        'NON-DRUG': 'NON_DRUG',
        'NONDRUG': 'NON_DRUG',
        'NOT_DRUG': 'NON_DRUG',
        'NO_DRUG': 'NON_DRUG',
        '1': 'DRUG',
        '0': 'NON_DRUG',
        'TRUE': 'DRUG',
        'FALSE': 'NON_DRUG',
        'YES': 'DRUG',
        'NO': 'NON_DRUG'
    }

    # Apply label mapping; unmapped values keep their original text so they
    # can be reported before being dropped below
    df['label'] = df['label'].map(label_mapping).fillna(df['label'])

    # Check for any remaining invalid labels
    valid_labels = ['DRUG', 'NON_DRUG']
    invalid_mask = ~df['label'].isin(valid_labels)
    if invalid_mask.any():
        invalid_labels = df.loc[invalid_mask, 'label'].unique()
        logger.warning(f"Found {invalid_mask.sum()} rows with invalid labels: {invalid_labels}")
        logger.warning("These will be dropped. Valid labels are: DRUG, NON_DRUG")
        df = df[~invalid_mask]

    # Analyze dataset balance
    label_counts = df['label'].value_counts()
    drug_count = label_counts.get("DRUG", 0)
    non_drug_count = label_counts.get("NON_DRUG", 0)
    drug_ratio = drug_count / len(df) if len(df) > 0 else 0

    logger.info(f"Your dataset analysis:")
    logger.info(f"  Total examples: {len(df)}")
    logger.info(f"  DRUG examples: {drug_count} ({drug_ratio:.1%})")
    logger.info(f"  NON_DRUG examples: {non_drug_count} ({(1-drug_ratio):.1%})")

    # Check if we need targeted augmentation
    need_augmentation = False
    augmentation_reason = []

    if drug_ratio < 0.2:  # Less than 20% drug examples
        need_augmentation = True
        augmentation_reason.append(f"low DRUG ratio ({drug_ratio:.1%})")

    if drug_count < 100:  # Less than 100 drug examples
        need_augmentation = True
        augmentation_reason.append(f"low DRUG count ({drug_count})")

    # Optional targeted augmentation for specific missing patterns
    if need_augmentation:
        logger.info(f"Dataset needs augmentation due to: {', '.join(augmentation_reason)}")
        logger.info("Adding targeted synthetic examples to improve model robustness...")

        # Add only the most critical synthetic examples that might be missing
        critical_drug_examples = [
            {"text": "Bro, check the Insta DM. That the white or the blue?", "label": "DRUG"},
            {"text": "White, straight from Mumbai. Cool, payment through crypto, right?", "label": "DRUG"},
            {"text": "Who's bringing the stuff? Raj, Tabs, Weed and Coke.", "label": "DRUG"},
            {"text": "Let's not overdose this time.", "label": "DRUG"},
            {"text": "Saturday Rave is confirmed, right? Yes, outskirts near Kanaka Pura.", "label": "DRUG"},
            {"text": "Got the hash and charas ready for pickup.", "label": "DRUG"},
            {"text": "Quality MDMA and LSD tabs available.", "label": "DRUG"},
            {"text": "Syringe and needle for the gear.", "label": "DRUG"},
            {"text": "Trip was amazing, need more powder.", "label": "DRUG"},
            {"text": "Package delivery confirmed, bring crypto payment.", "label": "DRUG"},
        ]

        synthetic_df = pd.DataFrame(critical_drug_examples)
        df = pd.concat([df, synthetic_df], ignore_index=True)
        logger.info(f"Added {len(critical_drug_examples)} targeted synthetic DRUG examples")
    else:
        logger.info("Dataset appears well-balanced, using your original data without augmentation")

    # Final statistics
    final_counts = df['label'].value_counts()
    final_drug_ratio = final_counts.get("DRUG", 0) / len(df)
    logger.info(f"Final dataset: {len(df)} examples")
    logger.info(f"Final DRUG ratio: {final_drug_ratio:.1%} ({final_counts.get('DRUG', 0)} examples)")
    logger.info(f"Final NON_DRUG ratio: {(1-final_drug_ratio):.1%} ({final_counts.get('NON_DRUG', 0)} examples)")

    return df


class WeightedTrainer(Trainer):
    """Trainer that applies per-class weights to the cross-entropy loss.

    Compensates for class imbalance by up-weighting the minority class when
    computing the training loss.
    """

    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sequence of per-class weights indexed by label id, or None to fall
        # back to the model's own (unweighted) loss.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer transformers versions pass
        # (e.g. num_items_in_batch) so this override stays forward-compatible.
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        if self.class_weights is not None:
            weight_tensor = torch.tensor(self.class_weights, device=labels.device, dtype=torch.float)
            loss_fct = torch.nn.CrossEntropyLoss(weight=weight_tensor)
            loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        else:
            loss = outputs.loss

        return (loss, outputs) if return_outputs else loss


def main():
    """Train, evaluate, save, and smoke-test the classifier."""
    # Load and prepare data
    df = load_and_augment_data()

    # Encode labels
    label2id = {"NON_DRUG": 0, "DRUG": 1}
    id2label = {0: "NON_DRUG", 1: "DRUG"}
    df['label_id'] = df['label'].map(label2id)

    # Compute class weights for imbalanced data (ordered by label id: 0, 1)
    class_weights = compute_class_weight(
        'balanced',
        classes=np.unique(df['label_id']),
        y=df['label_id']
    )
    logger.info(f"Computed class weights: NON_DRUG={class_weights[0]:.3f}, DRUG={class_weights[1]:.3f}")

    # Split dataset with stratification
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        df['text'].tolist(),
        df['label_id'].tolist(),
        test_size=0.2,
        random_state=42,
        stratify=df['label_id']  # Ensure balanced split
    )

    logger.info(f"Training set: {len(train_texts)} samples")
    logger.info(f"Validation set: {len(val_texts)} samples")

    # Check balance in splits
    train_drug_ratio = sum(train_labels) / len(train_labels)
    val_drug_ratio = sum(val_labels) / len(val_labels)
    logger.info(f"Train DRUG ratio: {train_drug_ratio:.2%}")
    logger.info(f"Validation DRUG ratio: {val_drug_ratio:.2%}")

    # Tokenizer
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    # Tokenize with appropriate max length
    max_length = 256  # Reduced for efficiency, most drug conversations are short
    train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=max_length)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=max_length)

    class DrugDataset(torch.utils.data.Dataset):
        """Wraps pre-tokenized encodings + labels as a torch Dataset."""

        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
            return item

        def __len__(self):
            return len(self.labels)

    train_dataset = DrugDataset(train_encodings, train_labels)
    val_dataset = DrugDataset(val_encodings, val_labels)

    # Load model
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased',
        num_labels=2,
        id2label=id2label,
        label2id=label2id
    )

    # Enhanced training arguments optimized for your 1500-entry dataset
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=6,                      # Good balance for 1500 examples
        per_device_train_batch_size=8,           # Suitable for most GPUs
        per_device_eval_batch_size=16,           # Larger batch for evaluation
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir='./logs',
        logging_steps=10,                        # Log every 10 steps
        learning_rate=2e-5,                      # Standard DistilBERT learning rate
        weight_decay=0.01,                       # L2 regularization
        # NOTE: this is len(train)/10 optimizer *steps*, not 10% of total
        # training steps (total steps = len/batch_size * epochs); kept as-is
        # since it still gives a reasonable short warmup.
        warmup_steps=len(train_dataset) // 10,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        save_total_limit=3,
        seed=42,                                 # For reproducibility
        fp16=torch.cuda.is_available(),          # Use mixed precision if GPU available
        dataloader_drop_last=False,
        # BUGFIX: report_to=None means "use default (all installed
        # integrations)" in transformers; "none" actually disables
        # wandb/tensorboard logging as intended.
        report_to="none",
    )

    def compute_metrics(pred):
        """Compute accuracy plus per-class and macro precision/recall/F1."""
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)

        # Per-class (index 0 = NON_DRUG, 1 = DRUG) and macro-averaged metrics
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds, average='macro')
        acc = accuracy_score(labels, preds)

        # Per-class metrics
        logger.info(f"Eval metrics:")
        logger.info(f"  Accuracy: {acc:.4f}")
        logger.info(f"  NON_DRUG - Precision: {precision[0]:.4f}, Recall: {recall[0]:.4f}, F1: {f1[0]:.4f}")
        logger.info(f"  DRUG - Precision: {precision[1]:.4f}, Recall: {recall[1]:.4f}, F1: {f1[1]:.4f}")
        logger.info(f"  Macro avg - Precision: {precision_macro:.4f}, Recall: {recall_macro:.4f}, F1: {f1_macro:.4f}")

        # Classification report
        logger.info("Detailed classification report:")
        logger.info(f"\n{classification_report(labels, preds, target_names=['NON_DRUG', 'DRUG'])}")

        return {
            'accuracy': acc,
            'f1': f1_macro,              # Use macro F1 as main metric
            'f1_drug': f1[1],            # F1 for DRUG class specifically
            'precision': precision_macro,
            'recall': recall_macro,
            'precision_drug': precision[1],
            'recall_drug': recall[1],
        }

    # Use weighted trainer to handle class imbalance
    trainer = WeightedTrainer(
        class_weights=class_weights,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    # Train the model
    logger.info("Starting training...")
    trainer.train()

    # Final evaluation
    logger.info("Running final evaluation...")
    eval_results = trainer.evaluate()
    logger.info(f"Final evaluation results: {eval_results}")

    # Save model and tokenizer
    output_dir = "drug_classifier_model"
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info(f"Model and tokenizer saved to '{output_dir}'")

    # Test with sample drug-related text
    logger.info("Testing model with sample drug-related text...")
    test_text = "Bro, check the Insta DM. That the white or the blue? White, straight from Mumbai. Cool, payment through crypto, right? Who's bringing the stuff? Raj, Tabs, Weed and Coke. Let's not overdose this time."

    # BUGFIX: after GPU training the model lives on CUDA while the freshly
    # tokenized tensors are on CPU — move inputs to the model's device, and
    # switch to eval mode so dropout is disabled for inference.
    model.eval()
    device = next(model.parameters()).device
    inputs = tokenizer(test_text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(predictions, dim=-1).item()
        drug_probability = predictions[0][1].item()  # Probability of DRUG class

    logger.info(f"Test prediction: {'DRUG' if predicted_class == 1 else 'NON_DRUG'}")
    logger.info(f"DRUG probability: {drug_probability:.4f} ({drug_probability*100:.2f}%)")
    logger.info(f"NON_DRUG probability: {1-drug_probability:.4f} ({(1-drug_probability)*100:.2f}%)")


if __name__ == "__main__":
    main()