import torch
import torch.nn as nn
from typing import List, Dict, Any
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class EnhancedProgressiveAMDModel(nn.Module):
    """BERT-based answering-machine-detection (AMD) model augmented with an
    utterance-count embedding.

    The backbone's own classification head is replaced with ``nn.Identity`` so
    the pooled sentence representation can be concatenated with a small
    embedding of "how many user utterances were available" before the enhanced
    classifier head produces a single logit.
    """

    # Number of distinct utterance-count buckets (0..3). Counts are clamped
    # into this range before the embedding lookup so out-of-range values from
    # callers cannot trigger an index error.
    NUM_COUNT_BUCKETS = 4

    def __init__(self, base_model_name: str, utterance_embedding_dim: int = 8):
        super().__init__()
        # Remember the configured width so forward() can build a matching
        # zero vector when no count is supplied (was hard-coded to 8 before,
        # which broke any model built with utterance_embedding_dim != 8).
        self.utterance_embedding_dim = utterance_embedding_dim

        # Base model with a single-logit head; the head itself is disabled
        # below in favour of `enhanced_classifier`.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            base_model_name, num_labels=1
        )

        # Embedding for the (bucketed) number of available user utterances.
        self.utterance_count_embedding = nn.Embedding(
            self.NUM_COUNT_BUCKETS, utterance_embedding_dim
        )

        # Enhanced classifier over [pooled BERT output ; count embedding].
        bert_hidden_size = self.bert.config.hidden_size
        self.enhanced_classifier = nn.Sequential(
            nn.Linear(bert_hidden_size + utterance_embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        )

        # Neutralize the wrapper's own classifier; we only use the backbone.
        self.bert.classifier = nn.Identity()

    def forward(self, input_ids, attention_mask, utterance_count=None):
        """Return raw logits of shape ``(batch, 1)``.

        Args:
            input_ids: token-id tensor for the backbone.
            attention_mask: matching attention mask.
            utterance_count: optional LongTensor of shape ``(batch,)``; when
                omitted, a zero vector stands in for the count embedding.
        """
        # NOTE(review): `self.bert.bert` assumes a BERT-style backbone with a
        # pooler (true for the default prajjwal1/bert-tiny); other
        # architectures expose the encoder under a different attribute —
        # confirm before swapping base models.
        bert_outputs = self.bert.bert(
            input_ids=input_ids, attention_mask=attention_mask
        )
        pooled_output = bert_outputs.pooler_output

        if utterance_count is not None:
            # Clamp defensively: the embedding table only has
            # NUM_COUNT_BUCKETS rows.
            utterance_count = utterance_count.clamp(
                0, self.NUM_COUNT_BUCKETS - 1
            )
            utterance_emb = self.utterance_count_embedding(utterance_count)
        else:
            # Fallback: zero "count" features of the *configured* width
            # (bug fix: previously hard-coded to 8).
            utterance_emb = torch.zeros(
                pooled_output.size(0),
                self.utterance_embedding_dim,
                device=pooled_output.device,
                dtype=pooled_output.dtype,
            )

        combined_features = torch.cat([pooled_output, utterance_emb], dim=1)
        return self.enhanced_classifier(combined_features)


class ProductionEnhancedAMDClassifier:
    """Production-ready enhanced AMD classifier.

    Loads a trained ``EnhancedProgressiveAMDModel`` checkpoint and classifies
    a call transcript as 'Machine' or 'Human' from up to the first three user
    utterances.
    """

    def __init__(
        self,
        model_path: str,
        tokenizer_name: str = 'prajjwal1/bert-tiny',
        device: str = 'auto',
    ):
        # Device resolution: prefer Apple MPS, then CUDA, then CPU.
        if device == 'auto':
            if torch.backends.mps.is_available():
                self.device = torch.device('mps')
            elif torch.cuda.is_available():
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = EnhancedProgressiveAMDModel(tokenizer_name)
        # weights_only=True: the checkpoint is a plain state_dict (tensors
        # only), so refuse arbitrary pickled objects from the file.
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True)
        )
        self.model.to(self.device)
        self.model.eval()

        self.max_length = 128
        self.threshold = 0.5
        print(f"Enhanced AMD classifier loaded on {self.device}")

    def extract_user_utterances(
        self, transcript: List[Dict[str, Any]]
    ) -> List[str]:
        """Return the non-empty 'content' strings of all 'user' turns,
        in transcript order. Speaker matching is case-insensitive."""
        user_utterances = []
        for utterance in transcript:
            if utterance.get("speaker", "").lower() == "user":
                content = utterance.get("content", "").strip()
                if content:
                    user_utterances.append(content)
        return user_utterances

    @torch.no_grad()
    def predict(self, transcript: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Classify a transcript.

        Returns a dict with keys: 'prediction' ('Machine'/'Human'),
        'machine_probability', 'confidence', 'utterance_count' (bucketed to
        at most 3) and 'available_utterances' (total user turns found).
        With no user utterances, defaults to 'Human' at confidence 0.5.
        """
        user_utterances = self.extract_user_utterances(transcript)

        if not user_utterances:
            return {
                'prediction': 'Human',
                'machine_probability': 0.0,
                'confidence': 0.5,
                'utterance_count': 0,
            }

        # Use at most the first three user utterances, joined into one text.
        combined_text = " ".join(user_utterances[:3])
        utterance_count = min(len(user_utterances), 3)

        encoding = self.tokenizer(
            combined_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        utterance_count_tensor = torch.tensor(
            [utterance_count], dtype=torch.long
        ).to(self.device)

        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            utterance_count=utterance_count_tensor,
        )

        machine_prob = torch.sigmoid(logits.squeeze(-1)).item()
        prediction = 'Machine' if machine_prob >= self.threshold else 'Human'
        # Confidence is the probability of whichever class was chosen.
        confidence = max(machine_prob, 1 - machine_prob)

        return {
            'prediction': prediction,
            'machine_probability': machine_prob,
            'confidence': confidence,
            'utterance_count': utterance_count,
            'available_utterances': len(user_utterances),
        }


# Usage:
# classifier = ProductionEnhancedAMDClassifier('path/to/model.pth')
# result = classifier.predict(transcript)