|
|
import os |
|
|
import sys |
|
|
import pickle |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
from sklearn.isotonic import IsotonicRegression |
|
|
from sklearn.linear_model import LogisticRegression |
|
|
import warnings |
|
|
# Silence all warnings process-wide for this script.
warnings.filterwarnings('ignore')

# Compatibility shim for loading pickles written under NumPy 2.x, which records
# module paths under `numpy._core` (NumPy 1.x only has `numpy.core`). When
# `numpy._core` is not present, alias the legacy package AND every
# already-imported `numpy.core.*` submodule (e.g. `numpy.core.multiarray`,
# which holds `_reconstruct` used by pickled arrays) under the new names.
# Without the submodule aliases, the import system would try to load a second
# copy of those modules under the `numpy._core.*` name.
if 'numpy._core' not in sys.modules:
    import numpy.core

    sys.modules['numpy._core'] = numpy.core
    sys.modules['numpy._core._multiarray_umath'] = numpy.core._multiarray_umath
    for _legacy_name in [n for n in sys.modules if n.startswith('numpy.core.')]:
        # setdefault: never clobber an alias that is already registered.
        sys.modules.setdefault(
            'numpy._core' + _legacy_name[len('numpy.core'):],
            sys.modules[_legacy_name],
        )
|
|
|
|
|
|
|
|
# Hub id of the base architecture. Not referenced by the code shown here: the
# tokenizer and weights are both loaded from CHECKPOINT_PATH.
MODEL_NAME = "microsoft/xtremedistil-l6-h256-uncased"

# Directory containing the fine-tuned model, tokenizer files and calibrators.pkl.
# If $BERT_CHECKPOINT_PATH is set (non-blank) it is used, so the module can be
# imported or run non-interactively; otherwise the user is prompted as before.
# `~` is expanded explicitly because from_pretrained() does not expand it.
_checkpoint_from_env = os.environ.get("BERT_CHECKPOINT_PATH", "").strip()
CHECKPOINT_PATH = os.path.expanduser(
    _checkpoint_from_env
    or input("Please enter the path to the BERT model directory: ").strip()
)

# Pickle holding a dict whose ['calibrators'] mapping includes 'mixnmatch'.
CALIBRATOR_FILE = os.path.join(CHECKPOINT_PATH, "calibrators.pkl")

# Token limit passed to the tokenizer (with truncation=True).
MAX_LENGTH = 512

# Number of texts tokenized and scored per forward pass.
BATCH_SIZE = 16

# CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TemperatureScaling:
    """Temperature-scaling calibrator: divides logits by a scalar temperature.

    ``temperature`` starts at 1.0, which makes ``transform`` the identity. No
    fitting logic is defined here; a learned temperature presumably arrives
    through an unpickled instance (TODO confirm against the training script).
    """

    def __init__(self):
        self.temperature = 1.0

    def transform(self, logits):
        """Return ``logits`` divided by ``self.temperature``.

        Accepts anything that supports ``/`` with a float (scalars, numpy
        arrays, torch tensors). The output is still logits, not probabilities.
        """
        temp = self.temperature
        scaled = logits / temp
        return scaled
|
|
|
|
|
class PlattScaling:
    """Platt-style calibrator: logistic regression on the softmax class-1 score.

    Only inference is implemented. ``calibrator`` and ``fitted`` are presumably
    populated from an unpickled, already-fitted instance (no fit method here).
    """

    def __init__(self):
        self.calibrator = LogisticRegression()
        self.fitted = False

    def transform(self, logits):
        """Map raw ``logits`` to calibrated class probabilities.

        The class-1 softmax probability is reshaped to a single-feature column
        and passed to ``calibrator.predict_proba``; its output is returned as-is.

        Raises:
            ValueError: if the calibrator has not been fitted.
        """
        if self.fitted:
            softmaxed = torch.softmax(torch.tensor(logits), dim=-1).numpy()
            feature_column = softmaxed[:, 1].reshape(-1, 1)
            return self.calibrator.predict_proba(feature_column)
        raise ValueError("Calibrator not fitted")
|
|
|
|
|
class IsotonicCalibration:
    """Isotonic-regression calibrator applied to the softmax class-1 score.

    ``calibrator`` is an sklearn ``IsotonicRegression`` configured to clip
    inputs outside its fitted range. No fit method is defined here; the fitted
    state and ``fitted = True`` presumably come from an unpickled instance.
    """

    def __init__(self):
        self.calibrator = IsotonicRegression(out_of_bounds='clip')
        self.fitted = False

    def transform(self, logits):
        """Return an ``(n, 2)`` float64 array with columns ``[1 - p, p]``.

        ``p`` is the isotonic-calibrated class-1 softmax probability.

        Raises:
            ValueError: if the calibrator has not been fitted.
        """
        if not self.fitted:
            raise ValueError("Calibrator not fitted")

        softmaxed = torch.softmax(torch.tensor(logits), dim=-1).numpy()
        adjusted = self.calibrator.transform(softmaxed[:, 1])

        two_column = np.empty((softmaxed.shape[0], 2))
        two_column[:, 0] = 1 - adjusted
        two_column[:, 1] = adjusted
        return two_column
|
|
|
|
|
class MixNMatchCalibration:
    """Temperature scaling followed by per-bin recalibration of the class-1 probability.

    ``transform`` divides logits by ``temperature`` and applies softmax. It then
    assigns each class-1 probability to a bin delimited by ``bin_boundaries``.
    A bin with an entry in ``bin_calibrators`` is recalibrated:

    * ``('isotonic', model)`` - ``model.predict(probs)``
    * ``('mean', factor)``    - ``probs * factor``

    The result is clipped to ``[0, 1]``. Bins without an entry are left
    unchanged. No fitting logic lives here; the fitted state presumably comes
    from an unpickled instance (TODO confirm against the training script).
    """

    def __init__(self, n_bins=15, bin_strategy='quantile'):
        # Requested bin count; transform uses len(bin_boundaries) - 1, which
        # normally equals this.
        self.n_bins = n_bins
        # Stored for reference only; not read by transform.
        self.bin_strategy = bin_strategy
        self.temperature = 1.0
        # Monotone bin edges; None until the instance has been fitted.
        self.bin_boundaries = None
        # bin index -> (cal_type, cal_data), or None / missing for identity.
        self.bin_calibrators = {}
        # Not read by transform.
        self.bin_sample_counts = {}

    def _num_bins(self):
        """Return the number of bins described by ``bin_boundaries``."""
        return len(self.bin_boundaries) - 1

    def _get_bin_mask(self, probs, bin_idx):
        """Return a boolean mask selecting the ``probs`` that fall in bin ``bin_idx``.

        Interior bins are half-open ``[lower, upper)``. The first bin is
        extended down to -inf and the last bin up to +inf. Without this,
        probabilities outside the range covered by ``bin_boundaries`` would
        match no bin at all.
        """
        lower = self.bin_boundaries[bin_idx]
        upper = self.bin_boundaries[bin_idx + 1]
        is_first = bin_idx == 0
        is_last = bin_idx == self._num_bins() - 1

        if is_first and is_last:
            return np.ones(np.shape(probs), dtype=bool)
        if is_first:
            return probs < upper
        if is_last:
            return probs >= lower
        return (probs >= lower) & (probs < upper)

    def transform(self, logits):
        """Return an ``(n, 2)`` array ``[1 - p, p]`` of calibrated probabilities.

        Raises:
            ValueError: if the calibrator is unfitted (``bin_boundaries`` is
                None) or a bin uses an unknown calibrator type.
        """
        if self.bin_boundaries is None:
            raise ValueError("Calibrator not fitted")

        scaled_logits = logits / self.temperature
        probs = torch.softmax(torch.tensor(scaled_logits), dim=-1).numpy()
        class1_probs = probs[:, 1]

        # Start from the uncalibrated value. Bins without a calibrator (and
        # anything no bin covers, e.g. NaN) pass through instead of becoming 0.
        calibrated_probs = class1_probs.copy()

        for i in range(self._num_bins()):
            entry = self.bin_calibrators.get(i)
            if entry is None:
                continue  # identity for this bin

            bin_mask = self._get_bin_mask(class1_probs, i)
            if not np.any(bin_mask):
                continue

            bin_probs = class1_probs[bin_mask]
            cal_type, cal_data = entry
            if cal_type == 'isotonic':
                calibrated_bin_probs = cal_data.predict(bin_probs)
            elif cal_type == 'mean':
                calibrated_bin_probs = bin_probs * cal_data
            else:
                # Fail loudly rather than reuse a stale value from a previous bin.
                raise ValueError(f"Unknown bin calibrator type: {cal_type!r}")

            calibrated_probs[bin_mask] = np.clip(calibrated_bin_probs, 0, 1)

        result = np.zeros((len(calibrated_probs), 2))
        result[:, 1] = calibrated_probs
        result[:, 0] = 1 - calibrated_probs
        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_calibrators():
    """Load the classifier, its tokenizer, and the Mix-n-Match calibrator.

    The tokenizer and model are both read from ``CHECKPOINT_PATH``. The model
    is moved to ``DEVICE`` and put in eval mode. Calibrators are unpickled from
    ``CALIBRATOR_FILE``, and the ``['calibrators']['mixnmatch']`` entry is used.

    Returns:
        tuple: ``(model, tokenizer, calibrator)``.

    Raises:
        KeyError: if the pickle has no ``['calibrators']['mixnmatch']`` entry.
        Errors from ``from_pretrained``, ``open`` or unpickling propagate.

    Security: unpickling can execute arbitrary code, and ``CALIBRATOR_FILE``
    lives in a user-supplied directory. Only load checkpoints you trust.
    """
    print(f"Loading BERT model from: {CHECKPOINT_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_PATH)
    model = model.to(DEVICE)
    model.eval()
    print("BERT model loaded successfully!")

    print(f"Loading calibrators from: {CALIBRATOR_FILE}")

    # The calibrator classes were presumably pickled from another module
    # (e.g. a training script's __main__). Resolve those class names to the
    # definitions in this module, whatever module path the pickle recorded.
    current_module = sys.modules[__name__]
    local_class_names = frozenset(
        {'TemperatureScaling', 'PlattScaling', 'IsotonicCalibration', 'MixNMatchCalibration'}
    )

    class CompatibleUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            if name in local_class_names:
                return getattr(current_module, name)
            # Map NumPy 2.x paths (numpy._core and every submodule, e.g.
            # numpy._core.multiarray) to their numpy.core equivalents. If the
            # legacy path cannot resolve the name, try the original path.
            if module == 'numpy._core' or module.startswith('numpy._core.'):
                legacy_module = 'numpy.core' + module[len('numpy._core'):]
                try:
                    return super().find_class(legacy_module, name)
                except (ImportError, AttributeError):
                    pass
            return super().find_class(module, name)

    try:
        with open(CALIBRATOR_FILE, 'rb') as f:
            cal_data = CompatibleUnpickler(f).load()
    except Exception:
        # Best-effort retry with the stock unpickler. If it also fails, Python
        # keeps the first error as __context__ in the traceback. `Exception`
        # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
        with open(CALIBRATOR_FILE, 'rb') as f:
            cal_data = pickle.load(f)

    calibrator = cal_data['calibrators']['mixnmatch']
    print("Using calibration: mixnmatch")

    return model, tokenizer, calibrator
|
|
|
|
|
def predict_batch(model, tokenizer, calibrator, texts):
    """Run the classifier on ``texts`` and return calibrated predictions.

    Texts are tokenized (truncated to ``MAX_LENGTH`` tokens) and scored in
    chunks of ``BATCH_SIZE`` under ``torch.no_grad()``. The stacked logits are
    then passed through ``calibrator.transform``.

    Args:
        model: sequence-classification model whose output has ``.logits``.
        tokenizer: tokenizer matching ``model``.
        calibrator: object with ``transform(logits)``. It returns either a 1-D
            array of class-1 probabilities or a 2-D per-class probability array.
        texts: a sequence of strings, or a single string.

    Returns:
        dict of numpy arrays:
            ``predictions`` - argmax of the calibrated probabilities;
            ``probabilities`` - calibrated probabilities;
            ``confidence`` - max calibrated probability per text;
            ``uncalibrated_probs`` - softmax of the raw logits;
            ``calibration_shift`` - calibrated minus uncalibrated confidence.
        For empty ``texts`` every array is empty. The probability arrays then
        have shape ``(0, 2)``, matching the two-column output of the calibrators.
    """
    if isinstance(texts, str):
        # A bare string would otherwise be sliced into character chunks below.
        texts = [texts]
    texts = list(texts)

    if not texts:
        # np.vstack([]) would raise; return well-shaped empty results instead.
        return {
            'predictions': np.empty(0, dtype=np.intp),
            'probabilities': np.empty((0, 2)),
            'confidence': np.empty(0),
            'uncalibrated_probs': np.empty((0, 2)),
            'calibration_shift': np.empty(0),
        }

    all_logits = []
    model.eval()
    with torch.no_grad():
        for start in range(0, len(texts), BATCH_SIZE):
            encoding = tokenizer(
                texts[start:start + BATCH_SIZE],
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors='pt',
            )
            input_ids = encoding['input_ids'].to(DEVICE)
            attention_mask = encoding['attention_mask'].to(DEVICE)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            all_logits.append(outputs.logits.cpu().numpy())

    logits = np.vstack(all_logits)
    uncalibrated_probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()

    calibrated_output = calibrator.transform(logits)
    if calibrated_output.ndim == 1:
        # A 1-D calibrator output is the class-1 probability; expand to [1-p, p].
        calibrated_probs = np.zeros((len(calibrated_output), 2))
        calibrated_probs[:, 1] = calibrated_output
        calibrated_probs[:, 0] = 1 - calibrated_output
    else:
        calibrated_probs = calibrated_output

    predictions = np.argmax(calibrated_probs, axis=1)
    confidence = np.max(calibrated_probs, axis=1)
    calibration_shift = confidence - np.max(uncalibrated_probs, axis=1)

    return {
        'predictions': predictions,
        'probabilities': calibrated_probs,
        'confidence': confidence,
        'uncalibrated_probs': uncalibrated_probs,
        'calibration_shift': calibration_shift,
    }
|
|
|
|
|
def process_texts(texts):
    """Load the model and calibrator, then return calibrated predictions for ``texts``.

    Convenience wrapper: every call reloads everything through
    ``load_model_and_calibrators`` and returns the dict that ``predict_batch``
    produces.
    """
    loaded = load_model_and_calibrators()  # (model, tokenizer, calibrator)
    return predict_batch(*loaded, texts)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test inputs. They appear to come in pairs: a pressuring or
    # manipulative message followed by a benign rewording of the same scenario.
    sample_texts = [
        "URGENT: Your account will be suspended in 2 hours due to suspicious activity. Click this link immediately to verify your identity or lose access forever. -IT Security Team",
        "Hi, we noticed some unusual login attempts on your account. For your security, please log into your account through our official website when convenient to review your recent activity. If you have concerns, contact our support team at [official number]. -IT Security Team",
        "Hey! It's Sarah from accounting. I'm working from home and can't access the expense system. Can you quickly send me your login details so I can process your reimbursement today? Thanks!",
        "Hi, this is Sarah from accounting. I'm having technical issues with the expense system. Could you please submit your reimbursement request through the official portal, or I can walk you through it when I'm back in the office tomorrow?",
        "I guess you don't really care about our friendship since you never make time for me anymore. I've been there for you through everything, but apparently that doesn't matter. Fine, I'll just stop trying.",
        "I miss spending time together and I'm feeling a bit disconnected lately. I understand you're busy, but I'd love to catch up when you have some free time. Would you be interested in planning something together?",
        "You're being way too sensitive about this. You always overreact to everything - I was just joking around. Maybe you should work on not taking things so personally all the time.",
        "I can see that what I said upset you, and that wasn't my intention. I was trying to be playful, but I can understand how it came across differently. I'm sorry for hurting your feelings.",
    ]

    print("Processing sample texts with BERT model...")
    results = process_texts(sample_texts)

    # Walk the per-text result arrays in lockstep with the inputs.
    per_text = zip(
        sample_texts,
        results['predictions'],
        results['confidence'],
        results['probabilities'],
        results['calibration_shift'],
    )
    for text, label, conf, class_probs, shift in per_text:
        print(f"\nText: {text}")
        print(f"Prediction: {label}")
        print(f"Confidence: {conf:.4f}")
        print(f"Probabilities: Class 0: {class_probs[0]:.4f}, Class 1: {class_probs[1]:.4f}")
        print(f"Calibration Shift: {shift:.4f}")