import torch
from transformers import (
    RobertaForSequenceClassification,
    DebertaForSequenceClassification,
    RobertaTokenizer,
    DebertaTokenizer,
    RobertaConfig,
    DebertaConfig,
)


class EnsembleInference:
    """Binary AI-text detector averaging the sigmoid probabilities of a
    RoBERTa and a DeBERTa sequence-classification head.

    The checkpoint at ``model_path`` must contain the keys
    ``'model_configs'`` (with ``'roberta_config'`` and ``'deberta_config'``
    sub-dicts), ``'roberta_state_dict'`` and ``'deberta_state_dict'``.
    """

    def __init__(self, model_path: str, device: str = 'cpu') -> None:
        """Fetch both tokenizers from the hub and load weights from ``model_path``.

        Args:
            model_path: Path to the combined torch checkpoint.
            device: Torch device string the models are moved to (default 'cpu').
        """
        self.device = device
        self.roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.deberta_tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
        self.load_models(model_path)

    def load_models(self, path: str) -> None:
        """Rebuild both classifiers from a combined checkpoint and set eval mode.

        Raises:
            KeyError: if the checkpoint lacks any of the expected keys.
        """
        # weights_only=False is required because the checkpoint stores plain
        # config dicts, not just tensors (torch >= 2.6 defaults to
        # weights_only=True, which would reject this file).
        # NOTE(security): this unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        state = torch.load(path, map_location=self.device, weights_only=False)

        roberta_config = RobertaConfig.from_dict(state['model_configs']['roberta_config'])
        deberta_config = DebertaConfig.from_dict(state['model_configs']['deberta_config'])

        self.roberta_model = RobertaForSequenceClassification(roberta_config).to(self.device)
        self.deberta_model = DebertaForSequenceClassification(deberta_config).to(self.device)

        self.roberta_model.load_state_dict(state['roberta_state_dict'])
        self.deberta_model.load_state_dict(state['deberta_state_dict'])

        # Inference only: freeze dropout/batch-norm behavior.
        self.roberta_model.eval()
        self.deberta_model.eval()

    def predict(self, text: str) -> dict:
        """Classify ``text`` as AI-generated or human-written.

        Args:
            text: The input passage to score.

        Returns:
            A dict with ``'prediction'`` ("AI generated" / "Human written"),
            the ensemble ``'confidence'`` and per-model confidences under
            ``'details'``, each formatted as a percentage string.
        """
        roberta_inputs = self.roberta_tokenizer(
            text, return_tensors='pt', truncation=True, padding=True, max_length=512
        )
        deberta_inputs = self.deberta_tokenizer(
            text, return_tensors='pt', truncation=True, padding=True, max_length=512
        )
        roberta_inputs = {k: v.to(self.device) for k, v in roberta_inputs.items()}
        deberta_inputs = {k: v.to(self.device) for k, v in deberta_inputs.items()}

        with torch.no_grad():
            roberta_logits = self.roberta_model(**roberta_inputs).logits.squeeze()
            deberta_logits = self.deberta_model(**deberta_inputs).logits.squeeze()

        # assumes num_labels == 1 so the squeezed logit is a scalar and
        # .item() succeeds — TODO confirm against the training config
        roberta_prob = torch.sigmoid(roberta_logits).item()
        deberta_prob = torch.sigmoid(deberta_logits).item()
        avg_prob = (roberta_prob + deberta_prob) / 2

        is_ai = avg_prob > 0.5
        prediction = "AI generated" if is_ai else "Human written"

        # Report confidence in the *chosen* label, not the raw AI probability,
        # so a 0.10 AI-probability shows as 90% confidence in "Human written".
        roberta_conf = roberta_prob if is_ai else 1 - roberta_prob
        deberta_conf = deberta_prob if is_ai else 1 - deberta_prob
        avg_conf = avg_prob if is_ai else 1 - avg_prob

        return {
            'prediction': prediction,
            'confidence': f"{avg_conf:.2%}",
            'details': {
                'roberta_confidence': f"{roberta_conf:.2%}",
                'deberta_confidence': f"{deberta_conf:.2%}",
            },
        }