Mahmoud59 committed on
Commit
4fc5f42
·
verified ·
1 Parent(s): e527253

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +62 -0
inference.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import (
3
+ RobertaForSequenceClassification,
4
+ DebertaForSequenceClassification,
5
+ RobertaTokenizer,
6
+ DebertaTokenizer,
7
+ RobertaConfig,
8
+ DebertaConfig
9
+ )
10
+
11
class EnsembleInference:
    """Classify text as AI generated or human written with a two-model ensemble.

    A RoBERTa and a DeBERTa sequence classifier are restored from a single
    checkpoint file. For each model the sigmoid of its output logit is used
    as the probability that the text is AI generated; the two probabilities
    are averaged to produce the final decision.
    """

    def __init__(self, model_path, device='cpu'):
        """Load both tokenizers and restore both models.

        Args:
            model_path: Path of the checkpoint file passed to ``torch.load``.
            device: Device the models and input tensors are moved to.
        """
        self.device = device
        self.roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.deberta_tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
        self.load_models(model_path)

    def load_models(self, path):
        """Rebuild both classifiers from the checkpoint at ``path``.

        The checkpoint must be a mapping with the keys ``'model_configs'``
        (itself holding ``'roberta_config'`` and ``'deberta_config'`` dicts),
        ``'roberta_state_dict'`` and ``'deberta_state_dict'``. Both models
        are left in evaluation mode.

        Raises:
            KeyError: If one of the expected checkpoint keys is missing.
        """
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints that come from a trusted source.
        state = torch.load(path, map_location=self.device)

        configs = state['model_configs']
        roberta_config = RobertaConfig.from_dict(configs['roberta_config'])
        deberta_config = DebertaConfig.from_dict(configs['deberta_config'])

        self.roberta_model = RobertaForSequenceClassification(roberta_config).to(self.device)
        self.deberta_model = DebertaForSequenceClassification(deberta_config).to(self.device)

        self.roberta_model.load_state_dict(state['roberta_state_dict'])
        self.deberta_model.load_state_dict(state['deberta_state_dict'])

        # Evaluation mode: no dropout during inference.
        self.roberta_model.eval()
        self.deberta_model.eval()

    def _positive_probability(self, tokenizer, model, text):
        """Return ``sigmoid(logit)`` of ``model`` for ``text`` as a float."""
        encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits.squeeze()

        # .item() requires exactly one element, so this only works for a
        # single input scored by a head that emits a single logit.
        return torch.sigmoid(logits).item()

    def predict(self, text):
        """Classify ``text`` and report the confidence of the decision.

        Args:
            text: The text to classify. The models must yield exactly one
                logit for it (see ``_positive_probability``).

        Returns:
            A dict with:
              * ``'prediction'``: ``"AI generated"`` when the averaged
                probability is strictly greater than 0.5, otherwise
                ``"Human written"``.
              * ``'confidence'``: the averaged probability of the predicted
                label, formatted as a percentage with two decimals.
              * ``'details'``: each model's own probability for that same
                predicted label (``'roberta_confidence'`` and
                ``'deberta_confidence'``), formatted the same way.
        """
        roberta_prob = self._positive_probability(
            self.roberta_tokenizer, self.roberta_model, text
        )
        deberta_prob = self._positive_probability(
            self.deberta_tokenizer, self.deberta_model, text
        )

        avg_prob = (roberta_prob + deberta_prob) / 2
        is_ai = avg_prob > 0.5
        prediction = "AI generated" if is_ai else "Human written"

        # Confidences are always expressed relative to the ensemble's label,
        # so a dissenting model shows up with a value below 50%.
        roberta_conf = roberta_prob if is_ai else 1 - roberta_prob
        deberta_conf = deberta_prob if is_ai else 1 - deberta_prob
        avg_conf = avg_prob if is_ai else 1 - avg_prob

        return {
            'prediction': prediction,
            'confidence': f"{avg_conf:.2%}",
            'details': {
                'roberta_confidence': f"{roberta_conf:.2%}",
                'deberta_confidence': f"{deberta_conf:.2%}"
            }
        }