import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
import json
from pathlib import Path
import warnings

# NOTE(review): this silences every warning process-wide as soon as the module
# is imported; consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

# Maximum number of user utterances the model considers. The utterance-count
# embedding therefore has _MAX_UTTERANCES + 1 rows (counts 0..3).
_MAX_UTTERANCES = 3

# Default per-stage confidence thresholds used by predict_progressive
# (stage k looks at the first k utterances).
_DEFAULT_STAGE_THRESHOLDS: Tuple[float, ...] = (0.95, 0.85, 0.75)


class EnhancedProgressiveAMDModel(nn.Module):
    """Enhanced AMD model with utterance count awareness.

    A pretrained transformer encoder produces a sequence representation, which
    is concatenated with a learned embedding of the user-utterance count
    (0-3) and passed through a small MLP that emits a single logit.
    """

    def __init__(self, model_name: str, utterance_embedding_dim: int = 8):
        """Build the model.

        Args:
            model_name: HuggingFace model name/path passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            utterance_embedding_dim: Size of the utterance-count embedding.
        """
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=1
        )
        self.utterance_embedding = nn.Embedding(
            _MAX_UTTERANCES + 1, utterance_embedding_dim
        )  # 0-3 utterances
        self.enhanced_classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size + utterance_embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        )

    def forward(self, input_ids, attention_mask, utterance_count):
        """Return raw (pre-sigmoid) logits of shape (batch, 1).

        Args:
            input_ids: Token ids from the tokenizer.
            attention_mask: Attention mask from the tokenizer.
            utterance_count: Long tensor of utterance counts, one per example,
                each in the range 0..3.
        """
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # The head's ``logits`` are (batch, num_labels=1), which cannot feed a
        # Linear layer that expects hidden_size + embedding_dim features. Use the
        # final layer's first-token hidden state (the [CLS] token for BERT-style
        # tokenizers), which is (batch, hidden_size).
        # NOTE(review): confirm this matches the representation used in training.
        bert_hidden = bert_outputs.hidden_states[-1][:, 0, :]

        # Utterance count embedding
        utt_emb = self.utterance_embedding(utterance_count)

        # Combine encoder output with utterance embedding, then classify.
        combined = torch.cat([bert_hidden, utt_emb], dim=-1)
        return self.enhanced_classifier(combined)


class ProductionEnhancedAMDClassifier:
    """Production-ready enhanced AMD classifier with comprehensive features."""

    def __init__(self, model_path: str, tokenizer_name: str, device: str = 'auto'):
        """Load the tokenizer and the trained model weights.

        Args:
            model_path: Path to a saved ``state_dict`` for
                ``EnhancedProgressiveAMDModel``.
            tokenizer_name: HuggingFace name/path used for both the tokenizer
                and the base encoder.
            device: ``'auto'`` (prefers MPS, then CUDA, then CPU) or any string
                accepted by ``torch.device``.
        """
        if device == 'auto':
            if torch.backends.mps.is_available():
                self.device = torch.device('mps')
            elif torch.cuda.is_available():
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load model.
        # NOTE(review): torch.load unpickles the file; only load checkpoints from
        # trusted sources (recent PyTorch supports weights_only=True for this).
        self.model = EnhancedProgressiveAMDModel(tokenizer_name)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        self.max_length = 128
        self.threshold = 0.5

        print(f"Enhanced AMD classifier loaded on {self.device}")

    def extract_user_utterances(self, transcript: List[Dict[str, Any]]) -> List[str]:
        """Extract non-empty user utterances in chronological order.

        An entry counts as a user utterance when its ``speaker`` is ``"user"``
        (case-insensitive). ``content`` is stripped and empty results are
        skipped. Missing or ``None`` fields are treated as empty strings.
        """
        user_utterances = []
        for utterance in transcript:
            speaker = utterance.get("speaker") or ""
            if speaker.lower() != "user":
                continue
            content = (utterance.get("content") or "").strip()
            if content:
                user_utterances.append(content)
        return user_utterances

    @torch.no_grad()
    def predict_enhanced(self, transcript: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Classify a transcript as 'Machine' or 'Human'.

        Up to the first three user utterances are joined with spaces and
        scored. The utterance count (capped at three) is fed to the model.

        Returns:
            A dict with prediction, machine_probability, confidence
            (max(p, 1 - p)), utterance_count, available_utterances,
            text_preview and reasoning.
        """
        user_utterances = self.extract_user_utterances(transcript)

        if not user_utterances:
            return {
                'prediction': 'Human',
                'machine_probability': 0.0,
                'confidence': 0.5,
                'utterance_count': 0,
                'available_utterances': 0,
                'text_preview': '',
                'reasoning': 'No user utterances found'
            }

        # Combine up to 3 utterances
        combined_text = " ".join(user_utterances[:_MAX_UTTERANCES])
        utterance_count = min(len(user_utterances), _MAX_UTTERANCES)

        # Tokenize
        encoding = self.tokenizer(
            combined_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        utterance_count_tensor = torch.tensor(
            [utterance_count], dtype=torch.long, device=self.device
        )

        # Predict
        logits = self.model(input_ids, attention_mask, utterance_count_tensor)
        machine_prob = torch.sigmoid(logits).item()

        prediction = 'Machine' if machine_prob >= self.threshold else 'Human'
        confidence = max(machine_prob, 1 - machine_prob)

        return {
            'prediction': prediction,
            'machine_probability': machine_prob,
            'confidence': confidence,
            'utterance_count': utterance_count,
            'available_utterances': len(user_utterances),
            'text_preview': combined_text[:100] + ('...' if len(combined_text) > 100 else ''),
            'reasoning': f'Processed {utterance_count} utterances with {confidence:.3f} confidence'
        }

    @staticmethod
    def _record_decision(results: Dict[str, Any], stage: int,
                         result: Dict[str, Any], reasoning: str) -> None:
        """Copy a stage's prediction into the progressive ``results`` dict."""
        results['final_decision'] = result['prediction'] == 'Machine'
        results['confidence'] = result['confidence']
        results['decision_stage'] = stage
        results['prediction'] = result['prediction']
        results['reasoning'] = reasoning

    def predict_progressive(self, utterances: List[str],
                            stage_thresholds: Optional[List[float]] = None) -> Dict[str, Any]:
        """Progressive utterance analysis for production AMD system.

        Stage k (k = 1..3) scores the first k utterances. The first stage whose
        confidence meets ``stage_thresholds[k - 1]`` decides. If no stage meets
        its threshold, the last stage that could be evaluated decides.

        Args:
            utterances: User utterances in chronological order.
            stage_thresholds: Per-stage confidence thresholds. Defaults to
                (0.95, 0.85, 0.75). Stages beyond its length never stop early.

        Returns:
            A dict with final_decision (True if Machine), confidence,
            decision_stage (0 if nothing was evaluated), stage_results,
            utterances_processed, prediction and reasoning.
        """
        if stage_thresholds is None:
            stage_thresholds = list(_DEFAULT_STAGE_THRESHOLDS)

        results = {
            'final_decision': False,
            'confidence': 0.0,
            'decision_stage': 0,
            'stage_results': [],
            'utterances_processed': 0,
            'prediction': 'Human',
            'reasoning': ''
        }

        last_stage: Optional[Tuple[int, Dict[str, Any]]] = None
        for stage in range(1, _MAX_UTTERANCES + 1):
            utterance_count = stage
            if len(utterances) < utterance_count:
                break

            stage_utterances = utterances[:utterance_count]
            combined_text = " ".join(stage_utterances)

            # Pass each utterance as its own entry so predict_enhanced sees the
            # real utterance count; a single pre-joined entry always counts as 1.
            transcript = [{"speaker": "user", "content": u} for u in stage_utterances]
            result = self.predict_enhanced(transcript)

            results['stage_results'].append({
                'stage': stage,
                'utterances': utterance_count,
                'confidence': result['confidence'],
                'machine_probability': result['machine_probability'],
                'text': combined_text[:100] + '...' if len(combined_text) > 100 else combined_text
            })
            results['utterances_processed'] = utterance_count
            last_stage = (stage, result)

            # Check if confidence meets threshold for this stage
            if stage <= len(stage_thresholds) and result['confidence'] >= stage_thresholds[stage - 1]:
                self._record_decision(
                    results, stage, result,
                    f'Decision made at stage {stage} with {result["confidence"]:.3f} confidence'
                )
                return results

        if last_stage is None:
            results['reasoning'] = 'No user utterances found'
            return results

        # No threshold met: decide on the last stage we could evaluate, even if
        # fewer than three utterances were available.
        stage, result = last_stage
        self._record_decision(
            results, stage, result,
            f'Final decision at stage {stage} with {result["confidence"]:.3f} confidence'
        )
        return results

    def batch_predict(self, transcripts: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Run ``predict_enhanced`` on each transcript, preserving input order."""
        return [self.predict_enhanced(transcript) for transcript in transcripts]

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata: device, parameter counts and tokenizer details."""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        return {
            'model_name': 'Enhanced Progressive AMD Classifier',
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'max_length': self.max_length,
            'threshold': self.threshold,
            'tokenizer_name': self.tokenizer.name_or_path,
            'vocab_size': self.tokenizer.vocab_size
        }


# Usage examples and testing functions
def test_production_classifier():
    """Smoke-test the production classifier on a few hand-written transcripts.

    Loads the checkpoint from 'output/best_enhanced_progressive_amd.pth',
    prints each prediction and returns the classifier instance.
    """
    # Initialize classifier
    classifier = ProductionEnhancedAMDClassifier(
        model_path='output/best_enhanced_progressive_amd.pth',
        tokenizer_name='prajjwal1/bert-tiny'
    )

    # Test cases
    test_cases = [
        # Human responses
        {
            'name': 'Single Human Utterance',
            'transcript': [{"speaker": "user", "content": "Yes, I'm here. What do you need?"}]
        },
        {
            'name': 'Multi Human Utterances',
            'transcript': [
                {"speaker": "user", "content": "Hello?"},
                {"speaker": "user", "content": "Yes, this is John speaking."},
                {"speaker": "user", "content": "How can I help you?"}
            ]
        },
        # Machine responses
        {
            'name': 'Voicemail Message',
            'transcript': [{"speaker": "user", "content": "Hi, you've reached John's voicemail. I'm not available right now, but please leave your name, number, and a brief message after the beep."}]
        },
        {
            'name': 'Automated Response',
            'transcript': [
                {"speaker": "user", "content": "The person you are trying to reach is not available."},
                {"speaker": "user", "content": "Please leave a message after the tone."}
            ]
        }
    ]

    print("Testing Production Enhanced AMD Classifier")
    print("=" * 60)

    for test_case in test_cases:
        print(f" Test: {test_case['name']}")
        result = classifier.predict_enhanced(test_case['transcript'])
        print(f" Prediction: {result['prediction']}")
        print(f" Machine Probability: {result['machine_probability']:.4f}")
        print(f" Confidence: {result['confidence']:.4f}")
        print(f" Utterance Count: {result['utterance_count']}")
        print(f" Text Preview: {result['text_preview']}")
        print(f" Reasoning: {result['reasoning']}")

    return classifier


if __name__ == "__main__":
    # Run tests
    test_production_classifier()