# Captured page header ("Spaces:" / "Runtime error"): the hosting Space reported
# a runtime error when this file was captured.
from typing import Any, Dict

import torch
from transformers import (
    DebertaConfig,
    DebertaForSequenceClassification,
    DebertaTokenizer,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
)
class EnsembleInference:
    """Classify text as AI generated or human written with a RoBERTa + DeBERTa ensemble.

    Each classifier must produce a single logit for the input text. That logit
    is passed through a sigmoid and read as the probability that the text is
    AI generated. The two probabilities are averaged, and the text is labelled
    "AI generated" when the average exceeds ``threshold``.
    """

    # Class-level defaults; ``__init__`` overrides them per instance.
    max_length = 512  # truncation length passed to both tokenizers
    threshold = 0.5  # averaged probability must be strictly above this to label "AI generated"

    def __init__(
        self,
        model_path,
        device='cpu',
        *,
        roberta_tokenizer_name='roberta-base',
        deberta_tokenizer_name='microsoft/deberta-base',
        max_length=512,
        threshold=0.5,
    ):
        """Load both tokenizers and the ensemble checkpoint.

        Args:
            model_path: Checkpoint file passed to ``torch.load`` (see ``load_models``).
            device: Device the models and the tokenized inputs are moved to.
            roberta_tokenizer_name: Name or path given to ``RobertaTokenizer.from_pretrained``.
            deberta_tokenizer_name: Name or path given to ``DebertaTokenizer.from_pretrained``.
            max_length: Truncation length for both tokenizers; must be positive.
            threshold: Averaged probability above which text is labelled
                "AI generated"; must lie in [0, 1].

        Raises:
            ValueError: If ``max_length`` or ``threshold`` is out of range.
        """
        if max_length <= 0:
            raise ValueError(f"max_length must be positive, got {max_length}")
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be in [0, 1], got {threshold}")
        self.device = device
        self.max_length = max_length
        self.threshold = threshold
        self.roberta_tokenizer = RobertaTokenizer.from_pretrained(roberta_tokenizer_name)
        self.deberta_tokenizer = DebertaTokenizer.from_pretrained(deberta_tokenizer_name)
        self.load_models(model_path)

    def load_models(self, path):
        """Build both classifiers from the configs stored in *path* and load their weights.

        The checkpoint must be a mapping with the keys ``'model_configs'``
        (itself holding ``'roberta_config'`` and ``'deberta_config'`` dicts),
        ``'roberta_state_dict'`` and ``'deberta_state_dict'``. Both models are
        moved to ``self.device`` and switched to eval mode.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Depending on the installed torch version, `weights_only`
        # may default to True and reject checkpoints that hold arbitrary objects.
        state = torch.load(path, map_location=self.device)
        configs = state['model_configs']
        roberta_config = RobertaConfig.from_dict(configs['roberta_config'])
        deberta_config = DebertaConfig.from_dict(configs['deberta_config'])

        self.roberta_model = RobertaForSequenceClassification(roberta_config).to(self.device)
        self.deberta_model = DebertaForSequenceClassification(deberta_config).to(self.device)
        self.roberta_model.load_state_dict(state['roberta_state_dict'])
        self.deberta_model.load_state_dict(state['deberta_state_dict'])
        self.roberta_model.eval()
        self.deberta_model.eval()

    def _ai_probability(self, model, tokenizer, text):
        """Return sigmoid(logit) of one classifier for *text* as a Python float.

        Raises:
            ValueError: If the model output does not contain exactly one logit,
                e.g. a two-label head or several texts passed at once.
        """
        encoded = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=self.max_length,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(**encoded).logits
        # Previously `.squeeze().item()` failed with an opaque RuntimeError here.
        if logits.numel() != 1:
            raise ValueError(
                "expected exactly one logit (single-output head, single text); "
                f"got logits of shape {tuple(logits.shape)}"
            )
        return torch.sigmoid(logits).item()

    def predict(self, text: str) -> Dict[str, Any]:
        """Classify *text* and report averaged and per-model confidence.

        Args:
            text: The text to classify (one text per call).

        Returns:
            A dict ``{'prediction': ..., 'confidence': ..., 'details': {...}}``.
            ``'prediction'`` is ``"AI generated"`` or ``"Human written"``. Each
            confidence is the probability of the *predicted* label, formatted
            as a percentage with two decimals. Per-model values can fall below
            50% when the two models disagree.

        Raises:
            ValueError: If either model does not return exactly one logit.
        """
        roberta_prob = self._ai_probability(self.roberta_model, self.roberta_tokenizer, text)
        deberta_prob = self._ai_probability(self.deberta_model, self.deberta_tokenizer, text)
        avg_prob = (roberta_prob + deberta_prob) / 2
        is_ai = avg_prob > self.threshold

        def confidence(prob):
            # Probability of the predicted label rather than of the "AI" class.
            return prob if is_ai else 1 - prob

        return {
            'prediction': "AI generated" if is_ai else "Human written",
            'confidence': f"{confidence(avg_prob):.2%}",
            'details': {
                'roberta_confidence': f"{confidence(roberta_prob):.2%}",
                'deberta_confidence': f"{confidence(deberta_prob):.2%}",
            },
        }