ayush2917's picture
Create utils.py
5b1b3f5 verified
# src/utils.py
import seaborn as sns
import matplotlib.pyplot as plt
import logging
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
def setup_logging(filename="logs/app.log", level=logging.INFO):
    """Configure root logging to write to a log file.

    Delegates to ``logging.basicConfig``, which only takes effect while the
    root logger has no handlers; subsequent calls are no-ops, which is why the
    other helpers in this module can call it unconditionally.

    Args:
        filename: Path of the log file. Its parent directory is created if it
            is missing, because the file handler will not create directories.
        level: Minimum level recorded by the root logger.
    """
    from pathlib import Path

    # basicConfig(filename=...) opens the file immediately and raises
    # FileNotFoundError if the directory (e.g. ``logs/``) does not exist yet.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(filename=filename, level=level,
                        format="%(asctime)s - %(levelname)s - %(message)s")
def plot_confusion_matrix(cm, labels, filename="docs/confusion_matrix.png"):
    """Render a confusion matrix as an annotated heatmap and save it as an image.

    Args:
        cm: Square matrix; rows are drawn along the "True" axis and columns
            along the "Predicted" axis. Cells are annotated with format
            ``"d"``, so entries are presumably integer counts — a float
            (normalized) matrix would make the annotation formatting fail.
        labels: Tick labels used for both axes, in the same order as ``cm``.
        filename: Output image path; missing parent directories are created.
    """
    from pathlib import Path

    setup_logging()
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=labels, yticklabels=labels, ax=ax)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        # savefig does not create directories; a fresh checkout may lack docs/.
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed, so
        # repeated calls would otherwise accumulate figures in memory.
        plt.close(fig)
    logging.info("Confusion matrix saved to %s", filename)
def load_model_and_tokenizer(model_path, tokenizer_path="distilbert-base-uncased"):
    """Load a DistilBERT sequence-classification model and a tokenizer.

    Args:
        model_path: Identifier or directory passed to
            ``DistilBertForSequenceClassification.from_pretrained``.
        tokenizer_path: Identifier or directory passed to
            ``DistilBertTokenizer.from_pretrained``. Defaults to
            ``"distilbert-base-uncased"`` (the previously hard-coded value).
            If the tokenizer was saved next to the fine-tuned model, pass
            ``model_path`` here so the vocabulary matches the model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    setup_logging()
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_path)
    # Report both sources: the tokenizer does not necessarily come from model_path.
    logging.info("Model loaded from %s; tokenizer loaded from %s",
                 model_path, tokenizer_path)
    return model, tokenizer