|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch.nn as nn |
|
|
from typing import List, Dict, Any |
|
|
import numpy as np |
|
|
|
|
|
class EnhancedProgressiveAMDModel(nn.Module):
    """Enhanced model that incorporates utterance count information.

    Wraps a HuggingFace sequence-classification backbone and concatenates an
    embedding of the (clamped) user-utterance count onto the pooled BERT
    output; a small MLP head then produces a single binary logit.
    """

    def __init__(self, base_model_name: str, utterance_embedding_dim: int = 8):
        """
        Args:
            base_model_name: HuggingFace model id/path for the BERT backbone.
            utterance_embedding_dim: Width of the utterance-count embedding.
        """
        super().__init__()

        # num_labels=1 creates a single-logit head, but that head is unused:
        # it is replaced with Identity below and forward() reads the pooled
        # output directly from the inner encoder.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            base_model_name, num_labels=1
        )

        # 4 buckets: count 0 (none/unknown) plus counts 1-3.
        self.utterance_count_embedding = nn.Embedding(4, utterance_embedding_dim)

        bert_hidden_size = self.bert.config.hidden_size
        self.enhanced_classifier = nn.Sequential(
            nn.Linear(bert_hidden_size + utterance_embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1)
        )

        # Disable the backbone's own classification head; the enhanced
        # classifier above is used instead.
        self.bert.classifier = nn.Identity()

    def forward(self, input_ids, attention_mask, utterance_count=None):
        """Return a (batch, 1) logit tensor; sigmoid of it gives P(machine).

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Attention mask, shape (batch, seq_len).
            utterance_count: Optional (batch,) long tensor of counts in
                [0, 3]; when None a zero embedding is substituted.
        """
        # NOTE(review): assumes the wrapped model exposes a `.bert` encoder
        # with a pooler (true for BERT-style checkpoints such as bert-tiny).
        bert_outputs = self.bert.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_outputs.pooler_output

        if utterance_count is not None:
            utterance_emb = self.utterance_count_embedding(utterance_count)
        else:
            # BUGFIX: the zero-embedding width was hard-coded to 8, which
            # broke any model built with a non-default utterance_embedding_dim.
            # new_zeros also matches pooled_output's dtype, not just its device.
            utterance_emb = pooled_output.new_zeros(
                pooled_output.size(0),
                self.utterance_count_embedding.embedding_dim,
            )

        combined_features = torch.cat([pooled_output, utterance_emb], dim=1)
        return self.enhanced_classifier(combined_features)
|
|
|
|
|
class ProductionEnhancedAMDClassifier:
    """Production-ready enhanced AMD classifier.

    Loads an EnhancedProgressiveAMDModel checkpoint and classifies call
    transcripts as answered by a 'Machine' or a 'Human' based on up to the
    first three user utterances.
    """

    def __init__(self, model_path: str, tokenizer_name: str = 'prajjwal1/bert-tiny', device: str = 'auto'):
        """
        Args:
            model_path: Path to a saved state_dict for EnhancedProgressiveAMDModel.
            tokenizer_name: HuggingFace tokenizer id (must match the trained backbone).
            device: 'auto' (prefer MPS, then CUDA, then CPU) or an explicit
                torch device string such as 'cpu' or 'cuda:0'.
        """
        if device == 'auto':
            # Preference order: Apple-Silicon MPS, then CUDA, then CPU.
            if torch.backends.mps.is_available():
                self.device = torch.device('mps')
            elif torch.cuda.is_available():
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = EnhancedProgressiveAMDModel(tokenizer_name)
        # SECURITY NOTE: torch.load unpickles arbitrary Python objects — only
        # load checkpoints from trusted sources (or pass weights_only=True on
        # torch >= 1.13).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        self.max_length = 128   # tokenizer padding/truncation length
        self.threshold = 0.5    # P(machine) decision boundary

        print(f"Enhanced AMD classifier loaded on {self.device}")

    def extract_user_utterances(self, transcript: List[Dict[str, Any]]) -> List[str]:
        """Return the non-empty, stripped contents of user-spoken turns.

        Args:
            transcript: List of turn dicts; each is expected to carry
                'speaker' and 'content' keys (missing keys are tolerated).
        """
        user_utterances = []
        for utterance in transcript:
            # Speaker matching is case-insensitive ("User" == "user").
            if utterance.get("speaker", "").lower() == "user":
                content = utterance.get("content", "").strip()
                if content:
                    user_utterances.append(content)
        return user_utterances

    @torch.no_grad()
    def predict(self, transcript: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Classify a transcript as answered by a machine or a human.

        Returns:
            Dict with keys 'prediction' ('Machine'/'Human'),
            'machine_probability', 'confidence', 'utterance_count'
            (clamped to 3) and 'available_utterances' (actual count).
        """
        user_utterances = self.extract_user_utterances(transcript)

        # No user speech yet: default to Human at minimum confidence.
        if not user_utterances:
            return {
                'prediction': 'Human',
                'machine_probability': 0.0,
                'confidence': 0.5,
                'utterance_count': 0,
                'available_utterances': 0  # BUGFIX: key was missing on this path
            }

        # Use at most the first three utterances; extract_user_utterances
        # already stripped them and dropped empties, so no re-filtering needed.
        combined_text = " ".join(user_utterances[:3])
        utterance_count = min(len(user_utterances), 3)

        encoding = self.tokenizer(
            combined_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        utterance_count_tensor = torch.tensor([utterance_count], dtype=torch.long).to(self.device)

        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            utterance_count=utterance_count_tensor
        )

        machine_prob = torch.sigmoid(logits.squeeze(-1)).item()
        prediction = 'Machine' if machine_prob >= self.threshold else 'Human'
        # Confidence is the distance from the 0.5 boundary, as a probability.
        confidence = max(machine_prob, 1 - machine_prob)

        return {
            'prediction': prediction,
            'machine_probability': machine_prob,
            'confidence': confidence,
            'utterance_count': utterance_count,
            'available_utterances': len(user_utterances)
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|