mihalykiss commited on
Commit
c402f62
·
verified ·
1 Parent(s): d51afa3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -29
app.py CHANGED
@@ -7,33 +7,27 @@ from tokenizers.normalizers import Sequence, Replace, Strip, NFKC
7
  from tokenizers import Regex
8
  import matplotlib.pyplot as plt
9
 
10
# Select the compute device: prefer a CUDA GPU, fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Model and Tokenizer Setup ---
# Model 1 is a local checkpoint; models 2 and 3 are fetched from the HF Hub.
model1_path = "modernbert.bin"
model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Load Model 1 from the local path.
# weights_only=True restricts unpickling to tensors and primitive containers,
# guarding against arbitrary-code execution from a tampered checkpoint file.
# A plain state_dict loads unchanged under this flag.
model_1 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
model_1.load_state_dict(torch.load(model1_path, map_location=device, weights_only=True))
model_1.to(device).eval()

# Load Model 2 from its URL (torch.hub caches the download after the first run).
# weights_only=True is especially important here: the checkpoint comes from a
# remote URL, so unrestricted unpickling would be a code-execution vector.
model_2 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
model_2.load_state_dict(torch.hub.load_state_dict_from_url(model2_path, map_location=device, weights_only=True))
model_2.to(device).eval()

# Load Model 3 from its URL, same safety considerations as Model 2.
model_3 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
model_3.load_state_dict(torch.hub.load_state_dict_from_url(model3_path, map_location=device, weights_only=True))
model_3.to(device).eval()
34
 
35
 
36
- # --- Label Mapping and Text Cleaning ---
37
  label_mapping = {
38
  0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
39
  6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
@@ -60,22 +54,17 @@ tokenizer.backend_tokenizer.normalizer = Sequence([
60
  newline_to_space,
61
  Strip()
62
  ])
63
-
64
-
65
  def classify_text(text):
66
  """
67
- Classifies the text and generates a plot of the top 5 AI model predictions.
68
  Returns both the result message and the plot figure.
69
  """
70
  cleaned_text = clean_text(text)
71
- # If input is empty, clear the outputs
72
  if not cleaned_text.strip():
73
  return "", None
74
 
75
- # Tokenize input and move to the appropriate device
76
  inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)
77
 
78
- # Perform inference with the three models
79
  with torch.no_grad():
80
  logits_1 = model_1(**inputs).logits
81
  logits_2 = model_2(**inputs).logits
@@ -110,34 +99,30 @@ def classify_text(text):
110
  f"**Identified LLM: {ai_argmax_model}**"
111
  )
112
 
 
 
113
 
114
- ai_probs_for_plot = probabilities.clone()
115
- top_5_probs, top_5_indices = torch.topk(ai_probs_for_plot, 5)
116
 
117
- top_5_probs = top_5_probs.cpu().numpy()
118
- top_5_labels = [label_mapping[i.item()] for i in top_5_indices]
 
 
119
 
120
- fig, ax = plt.subplots(figsize=(10, 5))
121
- bars = ax.barh(top_5_labels, top_5_probs, color='#4CAF50', alpha=0.8)
122
- ax.set_xlabel('Probability', fontsize=12)
123
- ax.set_title('Top 5 Predictions', fontsize=14, fontweight='bold')
124
- ax.invert_yaxis()
125
- ax.grid(axis='x', linestyle='--', alpha=0.6)
126
-
127
-
128
  for bar in bars:
129
- width = bar.get_width()
130
- label_x_pos = width + 0.01
131
- ax.text(label_x_pos, bar.get_y() + bar.get_height() / 2, f'{width:.2%}', va='center')
132
 
133
- ax.set_xlim(0, max(top_5_probs) * 1.18)
134
  plt.tight_layout()
135
 
136
-
137
  return result_message, fig
138
 
139
 
140
 
 
141
  title = "AI Text Detector"
142
 
143
  description = """
 
7
  from tokenizers import Regex
8
  import matplotlib.pyplot as plt
9
 
 
10
# Pick a GPU when CUDA is present; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint locations: one local file plus two hosted on the Hugging Face Hub.
model1_path = "modernbert.bin"
model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")


def _fresh_classifier():
    # Build a 41-way ModernBERT sequence classifier ready to receive weights.
    return AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)


# Model 1: weights read from the local checkpoint file.
model_1 = _fresh_classifier()
model_1.load_state_dict(torch.load(model1_path, map_location=device))
model_1.to(device).eval()

# Model 2: weights fetched from its URL via torch.hub.
model_2 = _fresh_classifier()
model_2.load_state_dict(torch.hub.load_state_dict_from_url(model2_path, map_location=device))
model_2.to(device).eval()

# Model 3: weights fetched from its URL via torch.hub.
model_3 = _fresh_classifier()
model_3.load_state_dict(torch.hub.load_state_dict_from_url(model3_path, map_location=device))
model_3.to(device).eval()
29
 
30
 
 
31
  label_mapping = {
32
  0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
33
  6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
 
54
  newline_to_space,
55
  Strip()
56
  ])
 
 
57
  def classify_text(text):
58
  """
59
+ Classifies the text and generates a plot of the human vs AI probability.
60
  Returns both the result message and the plot figure.
61
  """
62
  cleaned_text = clean_text(text)
 
63
  if not cleaned_text.strip():
64
  return "", None
65
 
 
66
  inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)
67
 
 
68
  with torch.no_grad():
69
  logits_1 = model_1(**inputs).logits
70
  logits_2 = model_2(**inputs).logits
 
99
  f"**Identified LLM: {ai_argmax_model}**"
100
  )
101
 
102
+ # Create a plot for Human vs AI probabilities
103
+ fig, ax = plt.subplots(figsize=(6, 3))
104
 
105
+ categories = ['Human', 'AI']
106
+ probabilities_for_plot = [human_percentage, ai_percentage]
107
 
108
+ bars = ax.bar(categories, probabilities_for_plot, color=['#4CAF50', '#FF5733'], alpha=0.8)
109
+ ax.set_ylabel('Probability (%)', fontsize=12)
110
+ ax.set_title('Human vs AI Probability', fontsize=14, fontweight='bold')
111
+ ax.grid(axis='y', linestyle='--', alpha=0.6)
112
 
113
+ # Add labels to the bars
 
 
 
 
 
 
 
114
  for bar in bars:
115
+ height = bar.get_height()
116
+ ax.text(bar.get_x() + bar.get_width() / 2, height + 1, f'{height:.2f}%', ha='center')
 
117
 
118
+ ax.set_ylim(0, 100)
119
  plt.tight_layout()
120
 
 
121
  return result_message, fig
122
 
123
 
124
 
125
+
126
  title = "AI Text Detector"
127
 
128
  description = """