# bert-tiny-amd / production_enhanced_amd.py
# Trained BERT-Tiny AMD (answering-machine detection) model
# (uploaded by Adya662, commit 4523f56, verified)
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn as nn
from typing import List, Dict, Any
import numpy as np
class EnhancedProgressiveAMDModel(nn.Module):
    """BERT-based binary classifier enhanced with an utterance-count signal.

    Wraps a `AutoModelForSequenceClassification` backbone, disables its own
    classification head, and feeds [pooled BERT output || utterance-count
    embedding] through a small MLP that emits a single logit.
    """

    def __init__(self, base_model_name: str, utterance_embedding_dim: int = 8):
        """
        Args:
            base_model_name: HF hub name/path of the base BERT model.
            utterance_embedding_dim: width of the utterance-count embedding
                that is concatenated onto the pooled BERT representation.
        """
        super().__init__()
        # Remember the embedding width so forward() can build a matching
        # zero vector when no utterance count is supplied (previously this
        # was hard-coded to 8, which broke any non-default dim).
        self.utterance_embedding_dim = utterance_embedding_dim
        # Base BERT model (its classifier head is replaced below).
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            base_model_name, num_labels=1
        )
        # Embedding over utterance counts 0..3 (counts are clamped to 3 upstream).
        self.utterance_count_embedding = nn.Embedding(4, utterance_embedding_dim)
        # Enhanced classifier over [pooled_output || utterance embedding].
        bert_hidden_size = self.bert.config.hidden_size
        self.enhanced_classifier = nn.Sequential(
            nn.Linear(bert_hidden_size + utterance_embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        )
        # Bypass the stock classification head; we only want the encoder.
        self.bert.classifier = nn.Identity()

    def forward(self, input_ids, attention_mask, utterance_count=None):
        """Return raw logits of shape (batch, 1).

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: attention mask, shape (batch, seq_len).
            utterance_count: optional LongTensor of shape (batch,) with values
                in [0, 3]; when None, a zero embedding is substituted.
        """
        bert_outputs = self.bert.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_outputs.pooler_output
        if utterance_count is not None:
            utterance_emb = self.utterance_count_embedding(utterance_count)
        else:
            # No count available: use a zero vector of the configured width
            # (was a hard-coded 8) matching pooled_output's device/dtype.
            utterance_emb = torch.zeros(
                pooled_output.size(0),
                self.utterance_embedding_dim,
                device=pooled_output.device,
                dtype=pooled_output.dtype,
            )
        combined_features = torch.cat([pooled_output, utterance_emb], dim=1)
        return self.enhanced_classifier(combined_features)
class ProductionEnhancedAMDClassifier:
    """Production-ready enhanced AMD (answering-machine detection) classifier.

    Loads an `EnhancedProgressiveAMDModel` checkpoint and classifies a call
    transcript as 'Machine' or 'Human' from the first three user utterances.
    """

    def __init__(self, model_path: str, tokenizer_name: str = 'prajjwal1/bert-tiny', device: str = 'auto'):
        """
        Args:
            model_path: path to a saved state_dict (.pth) for the model.
            tokenizer_name: HF tokenizer name; also used as the base model.
            device: 'auto' (prefer MPS, then CUDA, then CPU) or an explicit
                torch device string such as 'cpu' or 'cuda:0'.
        """
        if device == 'auto':
            if torch.backends.mps.is_available():
                self.device = torch.device('mps')
            elif torch.cuda.is_available():
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = EnhancedProgressiveAMDModel(tokenizer_name)
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (weights_only=True on torch>=1.13
        # would harden this for plain state_dicts).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.max_length = 128
        self.threshold = 0.5
        print(f"Enhanced AMD classifier loaded on {self.device}")

    def extract_user_utterances(self, transcript: List[Dict[str, Any]]) -> List[str]:
        """Return the stripped, non-empty contents of 'user' turns, in order."""
        return [
            content
            for utterance in transcript
            if utterance.get("speaker", "").lower() == "user"
            if (content := utterance.get("content", "").strip())
        ]

    @torch.no_grad()
    def predict(self, transcript: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Classify a transcript.

        Args:
            transcript: list of turns, each a dict with 'speaker' and 'content'.

        Returns:
            Dict with 'prediction' ('Machine'/'Human'), 'machine_probability',
            'confidence', 'utterance_count' (clamped to 3) and
            'available_utterances' (total user utterances found).
        """
        user_utterances = self.extract_user_utterances(transcript)
        if not user_utterances:
            # No user speech: default to Human with neutral confidence.
            # ('available_utterances' was previously missing from this branch,
            # making the result schema inconsistent with the normal path.)
            return {
                'prediction': 'Human',
                'machine_probability': 0.0,
                'confidence': 0.5,
                'utterance_count': 0,
                'available_utterances': 0,
            }
        # Model was trained on up to three utterances; join the first three.
        combined_text = " ".join(utt for utt in user_utterances[:3] if utt.strip())
        utterance_count = min(len(user_utterances), 3)
        encoding = self.tokenizer(
            combined_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        utterance_count_tensor = torch.tensor([utterance_count], dtype=torch.long).to(self.device)
        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            utterance_count=utterance_count_tensor
        )
        machine_prob = torch.sigmoid(logits.squeeze(-1)).item()
        prediction = 'Machine' if machine_prob >= self.threshold else 'Human'
        # Confidence is the probability of the predicted class.
        confidence = max(machine_prob, 1 - machine_prob)
        return {
            'prediction': prediction,
            'machine_probability': machine_prob,
            'confidence': confidence,
            'utterance_count': utterance_count,
            'available_utterances': len(user_utterances),
        }
# Usage:
# classifier = ProductionEnhancedAMDClassifier('path/to/model.pth')
# result = classifier.predict(transcript)