Spaces:
Runtime error
Runtime error
File size: 2,597 Bytes
import torch
from transformers import (
RobertaForSequenceClassification,
DebertaForSequenceClassification,
RobertaTokenizer,
DebertaTokenizer,
RobertaConfig,
DebertaConfig
)
class EnsembleInference:
    """Label text as AI generated or human written with a two-model ensemble.

    A RoBERTa and a DeBERTa sequence classifier each score the input. The
    sigmoid of each model's logit is read as the probability that the text
    is AI generated, and the two probabilities are averaged to make the
    final decision.
    """

    # Tokenizer checkpoints fetched with ``from_pretrained``.
    ROBERTA_TOKENIZER_NAME = "roberta-base"
    DEBERTA_TOKENIZER_NAME = "microsoft/deberta-base"
    # Inputs longer than this many tokens are truncated by the tokenizers.
    MAX_LENGTH = 512
    # An averaged probability strictly above this is labelled "AI generated".
    THRESHOLD = 0.5
    # Top-level entries the checkpoint must contain.
    _REQUIRED_KEYS = ("model_configs", "roberta_state_dict", "deberta_state_dict")

    def __init__(self, model_path, device='cpu'):
        """Load both tokenizers and restore both models.

        Args:
            model_path: Checkpoint location handed to ``torch.load``.
            device: Torch device that models and inputs are moved to.
        """
        self.device = device
        self.roberta_tokenizer = RobertaTokenizer.from_pretrained(
            self.ROBERTA_TOKENIZER_NAME
        )
        self.deberta_tokenizer = DebertaTokenizer.from_pretrained(
            self.DEBERTA_TOKENIZER_NAME
        )
        self.load_models(model_path)

    def load_models(self, path):
        """Rebuild both classifiers from the checkpoint at *path*.

        The checkpoint must hold ``model_configs`` (with ``roberta_config``
        and ``deberta_config`` dicts), ``roberta_state_dict`` and
        ``deberta_state_dict``. On success ``self.roberta_model`` and
        ``self.deberta_model`` are set, on ``self.device`` and in eval mode.

        Args:
            path: Checkpoint location handed to ``torch.load``.

        Raises:
            KeyError: If a required checkpoint entry is missing.
        """
        # SECURITY: torch.load unpickles the file. Depending on the installed
        # PyTorch version's ``weights_only`` default this can execute code
        # embedded in the checkpoint, so only load files from a trusted source.
        state = torch.load(path, map_location=self.device)

        # Fail fast, before any model is allocated, and say what is missing.
        missing = [key for key in self._REQUIRED_KEYS if key not in state]
        if missing:
            raise KeyError(
                f"checkpoint {path!r} is missing required entries: {missing}"
            )

        configs = state['model_configs']
        roberta_config = RobertaConfig.from_dict(configs['roberta_config'])
        deberta_config = DebertaConfig.from_dict(configs['deberta_config'])

        self.roberta_model = self._restore_model(
            RobertaForSequenceClassification,
            roberta_config,
            state['roberta_state_dict'],
        )
        self.deberta_model = self._restore_model(
            DebertaForSequenceClassification,
            deberta_config,
            state['deberta_state_dict'],
        )

    def predict(self, text):
        """Classify *text* and report how confident the ensemble is.

        Args:
            text: The text to classify.

        Returns:
            A dict with ``prediction`` ("AI generated" or "Human written"),
            ``confidence`` (the averaged probability of the predicted label,
            formatted as a percentage) and ``details`` holding each model's
            own confidence in that same label.
        """
        roberta_prob = self._ai_probability(
            self.roberta_tokenizer, self.roberta_model, text
        )
        deberta_prob = self._ai_probability(
            self.deberta_tokenizer, self.deberta_model, text
        )

        avg_prob = (roberta_prob + deberta_prob) / 2
        is_ai = avg_prob > self.THRESHOLD
        prediction = "AI generated" if is_ai else "Human written"

        # Confidence is always expressed for the predicted label, so flip the
        # probabilities when the verdict is "Human written".
        roberta_conf, deberta_conf, avg_conf = (
            prob if is_ai else 1 - prob
            for prob in (roberta_prob, deberta_prob, avg_prob)
        )

        return {
            'prediction': prediction,
            'confidence': f"{avg_conf:.2%}",
            'details': {
                'roberta_confidence': f"{roberta_conf:.2%}",
                'deberta_confidence': f"{deberta_conf:.2%}"
            }
        }

    def _restore_model(self, model_cls, config, weights):
        """Build *model_cls* from *config*, load *weights*, switch to eval mode."""
        model = model_cls(config).to(self.device)
        model.load_state_dict(weights)
        model.eval()
        return model

    def _ai_probability(self, tokenizer, model, text):
        """Return the sigmoid of *model*'s logit for *text* as a Python float."""
        encoded = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=self.MAX_LENGTH,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(**encoded).logits.squeeze()
            # .item() accepts only a one-element tensor, so this path needs a
            # single input text and a classification head with one logit.
            return torch.sigmoid(logits).item()
|