File size: 2,611 Bytes
9bf1d31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from ..base_handler import ModelHandler
from transformers import AutoTokenizer
import torch
import time

class TokenClassificationHandler(ModelHandler):
    """Run, decode and compare token-classification models.

    The handler tokenizes text with the tokenizer loaded for ``model_name``,
    times a forward pass, maps per-token logits to label names and computes
    agreement metrics between an original and a quantized model.

    NOTE(review): ``self.device``, ``self.test_text``, ``self.original_model``
    and ``self.quantized_model`` are read here but never assigned in this
    class; they are presumably provided by ``ModelHandler`` -- confirm.
    """

    def __init__(self, model_name, model_class, quantization_type, test_text):
        """Initialise the base handler and load the tokenizer for *model_name*."""
        super().__init__(model_name, model_class, quantization_type, test_text)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Input IDs of the most recent run_inference() call. Model outputs do
        # not carry the input IDs, so they are kept here to let the decoding
        # helpers recover the token strings that match the logits.
        self._last_input_ids = None

    def run_inference(self, model, text):
        """Tokenize *text*, run *model* on it and time the forward pass.

        Args:
            model: Callable model that accepts the tokenizer output as
                keyword arguments.
            text: Input passed to the tokenizer (truncated and padded).

        Returns:
            Tuple ``(outputs, elapsed_seconds)`` where ``outputs`` is whatever
            the model returned and ``elapsed_seconds`` covers only the
            forward pass, not tokenization.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
        self._last_input_ids = inputs['input_ids'].detach().cpu()

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        # NOTE(review): on a CUDA device kernels may be launched
        # asynchronously, so this can under-report latency unless the
        # device is synchronized -- confirm whether that matters here.
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = model(**inputs)
        elapsed = time.perf_counter() - start_time
        return outputs, elapsed

    def _tokens_for(self, outputs, seq_len):
        """Return the token strings of the first sequence behind *outputs*.

        Lookup order:
        1. ``outputs['input_ids']`` if the object happens to provide it;
        2. the IDs cached by the last ``run_inference`` call, provided their
           sequence length equals *seq_len*;
        3. a fresh tokenization of ``self.test_text`` with truncation enabled.
        """
        ids = None
        try:
            ids = outputs['input_ids'][0]
        except (KeyError, IndexError, TypeError):
            cached = getattr(self, '_last_input_ids', None)
            # Only trust the cache when it lines up with the logits.
            if cached is not None and cached.shape[-1] == seq_len:
                ids = cached[0]

        if ids is None:
            encoded = self.tokenizer(self.test_text, return_tensors='pt', truncation=True)
            ids = encoded['input_ids'][0]

        if hasattr(ids, 'tolist'):
            ids = ids.tolist()
        return self.tokenizer.convert_ids_to_tokens(ids)

    def decode_output(self, model, outputs):
        """Map the first sequence in *outputs* to ``{token: label}``.

        Args:
            model: Model whose ``config.id2label`` translates class indices
                to label names.
            outputs: Model output exposing a ``logits`` tensor.

        Returns:
            Dict from token string to predicted label name. Because the
            result is a dict, a token that occurs more than once keeps only
            the label of its last occurrence.
        """
        predictions = torch.argmax(outputs.logits, dim=-1)
        # Select the first sequence explicitly instead of squeeze(): squeeze
        # collapses a one-token sequence to a scalar and leaves batches
        # larger than one as nested lists, both of which break the lookup.
        if predictions.dim() > 1:
            predictions = predictions[0]
        label_ids = predictions.tolist()

        tokens = self._tokens_for(outputs, len(label_ids))
        id2label = model.config.id2label
        decoded_labels = [id2label[label_id] for label_id in label_ids]
        return dict(zip(tokens, decoded_labels))

    def compare_outputs(self, original_outputs, quantized_outputs):
        """Compare outputs for token classification models.

        Args:
            original_outputs: Output of the original model (needs ``logits``).
            quantized_outputs: Output of the quantized model (needs ``logits``).

        Returns:
            ``None`` if either argument is ``None``; otherwise a dict with
            ``original_predictions`` and ``quantized_predictions`` (lists of
            ``(token, label)`` pairs for the first sequence),
            ``token_level_accuracy``, ``sequence_exact_match`` and
            ``logits_mse`` (all plain Python floats).
        """
        if original_outputs is None or quantized_outputs is None:
            return None

        # Cast to float32 first: NumPy cannot represent bfloat16, and the
        # squared error is computed more accurately than in half precision.
        orig_logits = original_outputs.logits.detach().float().cpu().numpy()
        quant_logits = quantized_outputs.logits.detach().float().cpu().numpy()

        orig_preds = orig_logits.argmax(axis=-1)
        quant_preds = quant_logits.argmax(axis=-1)

        input_tokens = self._tokens_for(original_outputs, len(orig_preds[0]))

        orig_id2label = self.original_model.config.id2label
        quant_id2label = self.quantized_model.config.id2label
        # int() turns NumPy integer scalars into plain ints for the lookup.
        orig_labels = [orig_id2label[int(p)] for p in orig_preds[0]]
        quant_labels = [quant_id2label[int(p)] for p in quant_preds[0]]

        original_results = list(zip(input_tokens, orig_labels))
        quantized_results = list(zip(input_tokens, quant_labels))

        token_matches = sum(o_label == q_label for o_label, q_label in zip(orig_labels, quant_labels))
        total_tokens = len(orig_labels)

        return {
            'original_predictions': original_results,
            'quantized_predictions': quantized_results,
            'token_level_accuracy': float(token_matches) / total_tokens if total_tokens > 0 else 0.0,
            'sequence_exact_match': float((orig_preds == quant_preds).all()),
            # Plain float, like the other metrics, so the dict serializes cleanly.
            'logits_mse': float(((orig_logits - quant_logits) ** 2).mean()),
        }