File size: 5,118 Bytes
4523f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn as nn
from typing import List, Dict, Any
import numpy as np

class EnhancedProgressiveAMDModel(nn.Module):
    """BERT-based AMD (answering-machine detection) model augmented with an
    embedding of the user-utterance count.

    The stock sequence-classification head is replaced with ``nn.Identity``;
    the pooled text representation is concatenated with a learned
    utterance-count embedding and fed through a small MLP that emits a
    single logit (sigmoid -> machine probability).
    """

    # Utterance counts are bucketed into 0..3 ("3" == three or more).
    NUM_UTTERANCE_BUCKETS = 4

    def __init__(self, base_model_name: str, utterance_embedding_dim: int = 8):
        super().__init__()

        # Remember the dim so forward() never relies on a hard-coded width.
        self.utterance_embedding_dim = utterance_embedding_dim

        # Base BERT model; num_labels=1 -> single-logit output.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            base_model_name, num_labels=1
        )

        # Embedding over the bucketed number of user utterances.
        self.utterance_count_embedding = nn.Embedding(
            self.NUM_UTTERANCE_BUCKETS, utterance_embedding_dim
        )

        # Enhanced classifier over [pooled text ; utterance-count] features.
        bert_hidden_size = self.bert.config.hidden_size
        self.enhanced_classifier = nn.Sequential(
            nn.Linear(bert_hidden_size + utterance_embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1)
        )

        # Bypass the stock classification head; we classify ourselves.
        self.bert.classifier = nn.Identity()

    def forward(self, input_ids, attention_mask, utterance_count=None):
        """Return raw logits of shape (batch, 1).

        Args:
            input_ids: token-id tensor, (batch, seq_len).
            attention_mask: attention mask, (batch, seq_len).
            utterance_count: optional LongTensor (batch,) with values in
                [0, NUM_UTTERANCE_BUCKETS); when None, a zero vector stands
                in for the count embedding.
        """
        # NOTE(review): `.bert` assumes a BertForSequenceClassification
        # backbone (e.g. bert-tiny); other architectures name the encoder
        # attribute differently — confirm if base_model_name changes.
        bert_outputs = self.bert.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_outputs.pooler_output

        if utterance_count is not None:
            utterance_emb = self.utterance_count_embedding(utterance_count)
        else:
            # FIX: width was hard-coded to 8, breaking any model built with a
            # non-default utterance_embedding_dim. Also match pooled dtype.
            utterance_emb = torch.zeros(
                pooled_output.size(0),
                self.utterance_embedding_dim,
                dtype=pooled_output.dtype,
                device=pooled_output.device,
            )

        combined_features = torch.cat([pooled_output, utterance_emb], dim=1)
        return self.enhanced_classifier(combined_features)

class ProductionEnhancedAMDClassifier:
    """Production-ready enhanced AMD (answering-machine detection) classifier.

    Wraps EnhancedProgressiveAMDModel: extracts up to the first three user
    utterances from a transcript, encodes them, and returns a
    Machine/Human prediction with probability and confidence.
    """

    def __init__(self, model_path: str, tokenizer_name: str = 'prajjwal1/bert-tiny', device: str = 'auto'):
        """Load tokenizer and weights, move the model to a device, set eval mode.

        Args:
            model_path: path to a saved state_dict (.pth).
            tokenizer_name: HF hub name; must match the checkpoint's base model.
            device: 'auto' picks mps > cuda > cpu; otherwise any torch device string.
        """
        if device == 'auto':
            if torch.backends.mps.is_available():
                self.device = torch.device('mps')
            elif torch.cuda.is_available():
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = EnhancedProgressiveAMDModel(tokenizer_name)
        # SECURITY: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True on
        # torch >= 2.0).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        self.max_length = 128   # tokenizer truncation / padding length
        self.threshold = 0.5    # machine-vs-human decision boundary

        print(f"Enhanced AMD classifier loaded on {self.device}")

    def extract_user_utterances(self, transcript: List[Dict[str, Any]]) -> List[str]:
        """Return the stripped, non-empty contents of user-spoken turns, in order."""
        user_utterances = []
        for utterance in transcript:
            # Speaker match is case-insensitive; missing keys are treated as empty.
            if utterance.get("speaker", "").lower() == "user":
                content = utterance.get("content", "").strip()
                if content:
                    user_utterances.append(content)
        return user_utterances

    @torch.no_grad()
    def predict(self, transcript: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Classify a transcript as Machine or Human.

        Returns a dict with keys: prediction ('Machine'/'Human'),
        machine_probability, confidence, utterance_count, and — when at
        least one user utterance exists — available_utterances.
        """
        user_utterances = self.extract_user_utterances(transcript)

        # No user speech yet: default to Human with an uninformative confidence.
        if not user_utterances:
            return {
                'prediction': 'Human',
                'machine_probability': 0.0,
                'confidence': 0.5,
                'utterance_count': 0
            }

        # FIX: the former utt1/utt2/utt3 length guards were dead code after the
        # early return above; a slice expresses the same "first three" intent.
        # Cap at 3 to stay inside the model's 4-bucket count embedding.
        combined_text = " ".join(user_utterances[:3])
        utterance_count = min(len(user_utterances), 3)

        encoding = self.tokenizer(
            combined_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        # Build directly on the target device instead of tensor(...).to(...).
        utterance_count_tensor = torch.tensor(
            [utterance_count], dtype=torch.long, device=self.device
        )

        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            utterance_count=utterance_count_tensor
        )

        machine_prob = torch.sigmoid(logits.squeeze(-1)).item()
        prediction = 'Machine' if machine_prob >= self.threshold else 'Human'
        # Confidence = probability of whichever class was predicted.
        confidence = max(machine_prob, 1 - machine_prob)

        return {
            'prediction': prediction,
            'machine_probability': machine_prob,
            'confidence': confidence,
            'utterance_count': utterance_count,
            'available_utterances': len(user_utterances)
        }

# Usage:
# classifier = ProductionEnhancedAMDClassifier('path/to/model.pth')
# result = classifier.predict(transcript)