File size: 1,101 Bytes
5b1b3f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# src/utils.py
import seaborn as sns
import matplotlib.pyplot as plt
import logging
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

def setup_logging(log_file="logs/app.log", level=logging.INFO):
    """Configure the root logger to append to ``log_file``.

    The log file's parent directory is created first.
    ``logging.basicConfig`` opens the file immediately and raises
    ``FileNotFoundError`` when that directory (``logs/`` by default)
    does not exist yet.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers. Repeated calls, such as the ones made by the other helpers in
    this module, are therefore harmless.

    Args:
        log_file: Path of the log file. Defaults to ``"logs/app.log"``.
        level: Root logger level. Defaults to ``logging.INFO``.
    """
    from pathlib import Path

    # Ensure the target directory exists before basicConfig opens the file.
    Path(log_file).parent.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(
        filename=log_file,
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

def plot_confusion_matrix(cm, labels, filename="docs/confusion_matrix.png"):
    """Plot a confusion matrix as an annotated heatmap and save it to disk.

    Args:
        cm: Confusion-matrix values, passed straight to ``sns.heatmap``.
            Cells are annotated with ``fmt="d"``, so integer counts are
            assumed (the ``"d"`` format does not accept float values).
        labels: Tick labels used for both axes: predicted class on x, true
            class on y.
        filename: Output image path (or any target ``savefig`` accepts).
            For path targets, the parent directory is created if missing.
            Defaults to ``"docs/confusion_matrix.png"``.
    """
    import os
    from pathlib import Path

    setup_logging()
    # Only real paths get a directory created; file-like targets are passed
    # through to savefig untouched.
    if isinstance(filename, (str, os.PathLike)):
        Path(filename).parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=labels,
            yticklabels=labels,
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.savefig(filename)
    finally:
        # Close the figure even if plotting/saving fails; otherwise every
        # call leaves an open pyplot figure behind.
        plt.close(fig)
    logging.info("Confusion matrix saved to %s", filename)

def load_model_and_tokenizer(model_path, tokenizer_path="distilbert-base-uncased"):
    """Load a DistilBERT sequence-classification model and a tokenizer.

    Args:
        model_path: Directory or model id passed to
            ``DistilBertForSequenceClassification.from_pretrained``.
        tokenizer_path: Directory or model id passed to
            ``DistilBertTokenizer.from_pretrained``. Defaults to the base
            ``"distilbert-base-uncased"`` tokenizer. Pass ``model_path`` to
            use a tokenizer saved alongside the model, if one was saved there.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    setup_logging()
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_path)
    # Report both sources: the tokenizer does not necessarily come from model_path.
    logging.info(
        "Model loaded from %s; tokenizer loaded from %s", model_path, tokenizer_path
    )
    return model, tokenizer