mahmoudsaber0 commited on
Commit
dec54ad
·
verified ·
1 Parent(s): db41da3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -71
app.py CHANGED
@@ -1,34 +1,59 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- import matplotlib.pyplot as plt
4
  import re
 
 
 
 
 
 
5
  import gradio as gr
6
 
7
- # --- Load models ---
 
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- model_1 = AutoModelForSequenceClassification.from_pretrained("roberta-base-openai-detector").to(device)
11
- model_2 = AutoModelForSequenceClassification.from_pretrained("roberta-large-openai-detector").to(device)
12
- model_3 = AutoModelForSequenceClassification.from_pretrained("Hello-SimpleAI/chatgpt-detector-roberta").to(device)
 
 
 
 
 
 
 
 
 
13
  tokenizer = AutoTokenizer.from_pretrained("roberta-base-openai-detector")
14
 
15
- # --- Label Mapping (example) ---
16
  label_mapping = {i: f"model_{i}" for i in range(25)}
17
  label_mapping[24] = "Human"
18
 
 
 
 
19
  def clean_text(text):
20
- text = re.sub(r'\s+', ' ', text).strip()
21
- return text
22
-
 
 
 
 
 
 
 
 
 
 
23
  def classify_text(text):
24
  cleaned_text = clean_text(text)
25
- if not cleaned_text.strip():
26
- return "", None
27
 
28
- # Split into paragraphs (two newlines)
29
  paragraphs = re.split(r'\n{2,}', cleaned_text)
30
  if len(paragraphs) == 1 and len(cleaned_text.split()) > 300:
31
- # Fallback: split by ~300 words
32
  words = cleaned_text.split()
33
  paragraphs = [' '.join(words[i:i + 300]) for i in range(0, len(words), 300)]
34
 
@@ -37,99 +62,92 @@ def classify_text(text):
37
 
38
  for i, para in enumerate(paragraphs):
39
  inputs = tokenizer(para, return_tensors="pt", truncation=True, padding=True).to(device)
 
40
 
41
- with torch.no_grad():
42
- logits_1 = model_1(**inputs).logits
43
- logits_2 = model_2(**inputs).logits
44
- logits_3 = model_3(**inputs).logits
 
 
45
 
46
- softmax_1 = torch.softmax(logits_1, dim=1)
47
- softmax_2 = torch.softmax(logits_2, dim=1)
48
- softmax_3 = torch.softmax(logits_3, dim=1)
49
 
50
- averaged_probabilities = (softmax_1 + softmax_2 + softmax_3) / 3
51
- probabilities = averaged_probabilities[0]
52
- all_probabilities.append(probabilities.cpu())
53
 
54
  human_prob = probabilities[24].item()
55
- ai_probs_clone = probabilities.clone()
56
- ai_probs_clone[24] = 0
57
- ai_total_prob = ai_probs_clone.sum().item()
58
 
59
- total = human_prob + ai_total_prob
60
  human_pct = (human_prob / total) * 100
61
- ai_pct = (ai_total_prob / total) * 100
62
- ai_index = torch.argmax(ai_probs_clone).item()
63
- ai_model = label_mapping[ai_index]
64
-
65
- short_preview = (para[:180] + "...") if len(para) > 180 else para
66
 
 
67
  paragraph_scores.append({
68
  "id": i + 1,
69
  "human": human_pct,
70
  "ai": ai_pct,
71
  "model": ai_model,
72
- "preview": short_preview
73
  })
74
 
75
- # --- Averages ---
76
  avg_human = sum(p["human"] for p in paragraph_scores) / len(paragraph_scores)
77
  avg_ai = sum(p["ai"] for p in paragraph_scores) / len(paragraph_scores)
78
 
79
  if avg_human > avg_ai:
80
- result_message = f"<b>Overall Result:</b> <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
81
  else:
82
- top_model = max(paragraph_scores, key=lambda p: p['ai'])['model']
83
- result_message = f"<b>Overall Result:</b> <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # --- Paragraph Analysis HTML ---
86
- html_output = f"<div style='font-family: Arial, sans-serif; line-height:1.6;'>{result_message}<br><br>"
87
- html_output += "<h3>Paragraph Analysis:</h3>"
88
 
89
  for p in paragraph_scores:
90
  color = "#28a745" if p["human"] > p["ai"] else "#FF5733"
91
- html_output += f"""
92
  <div style='margin-bottom:10px; border-left:5px solid {color}; padding-left:10px; background:#f9f9f9; border-radius:6px;'>
93
  <b>Paragraph {p["id"]}</b>: {p["human"]:.2f}% Human | {p["ai"]:.2f}% AI → <i>{p["model"]}</i><br>
94
  <small>{p["preview"]}</small>
95
  </div>
96
  """
97
 
98
- html_output += "</div>"
99
-
100
- # --- Top 5 Plot ---
101
- mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
102
- top_5_probs, top_5_indices = torch.topk(mean_probs, 5)
103
- top_5_probs = top_5_probs.cpu().numpy()
104
- top_5_labels = [label_mapping[i.item()] for i in top_5_indices]
105
-
106
- fig, ax = plt.subplots(figsize=(10, 5))
107
- bars = ax.barh(top_5_labels, top_5_probs, color='#4CAF50', alpha=0.8)
108
- ax.set_xlabel('Probability', fontsize=12)
109
- ax.set_title('Top 5 Predictions (Averaged)', fontsize=14, fontweight='bold')
110
- ax.invert_yaxis()
111
- ax.grid(axis='x', linestyle='--', alpha=0.6)
112
- for bar in bars:
113
- width = bar.get_width()
114
- ax.text(width + 0.01, bar.get_y() + bar.get_height() / 2, f'{width:.2%}', va='center')
115
- ax.set_xlim(0, max(top_5_probs) * 1.18)
116
- plt.tight_layout()
117
 
118
- return html_output, fig
119
-
120
- # --- Gradio UI ---
121
  css = """
122
  .highlight-ai { color: #FF5733; font-weight: bold; }
123
  .highlight-human { color: #28a745; font-weight: bold; }
124
  """
125
 
126
  with gr.Blocks(css=css, theme="soft") as demo:
127
- gr.Markdown("# 🧠 AI/Human Text Detector")
128
- text_input = gr.Textbox(label="Paste your text here", lines=12, placeholder="Paste your article or essay...")
129
- output_html = gr.HTML(label="Analysis Results")
130
- output_plot = gr.Plot(label="Top 5 Models")
131
- analyze_btn = gr.Button("🔍 Analyze Text", variant="primary")
132
-
133
- analyze_btn.click(classify_text, inputs=text_input, outputs=[output_html, output_plot])
134
 
135
  demo.launch()
 
1
  import torch
 
 
2
  import re
3
+ import io
4
+ import base64
5
+ import matplotlib
6
+ matplotlib.use("Agg")
7
+ import matplotlib.pyplot as plt
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import gradio as gr
10
 
11
+ # ===============================
12
+ # Safe model loading
13
+ # ===============================
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ def safe_load_model(name):
17
+ try:
18
+ return AutoModelForSequenceClassification.from_pretrained(name).to(device)
19
+ except Exception as e:
20
+ print(f"[WARN] Failed to load {name}: {e}")
21
+ return None
22
+
23
+ print("Loading models...")
24
+
25
+ model_1 = safe_load_model("roberta-base-openai-detector")
26
+ model_2 = safe_load_model("roberta-large-openai-detector")
27
+ model_3 = safe_load_model("Hello-SimpleAI/chatgpt-detector-roberta")
28
  tokenizer = AutoTokenizer.from_pretrained("roberta-base-openai-detector")
29
 
 
30
  label_mapping = {i: f"model_{i}" for i in range(25)}
31
  label_mapping[24] = "Human"
32
 
33
+ # ===============================
34
+ # Helper functions
35
+ # ===============================
36
  def clean_text(text):
37
+ return re.sub(r'\s+', ' ', text).strip()
38
+
39
+ def plot_to_base64(fig):
40
+ buf = io.BytesIO()
41
+ fig.savefig(buf, format="png", bbox_inches="tight")
42
+ buf.seek(0)
43
+ img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
44
+ plt.close(fig)
45
+ return f"<img src='data:image/png;base64,{img_base64}' style='max-width:100%; border-radius:8px;'>"
46
+
47
+ # ===============================
48
+ # Main classification logic
49
+ # ===============================
50
  def classify_text(text):
51
  cleaned_text = clean_text(text)
52
+ if not cleaned_text:
53
+ return "<b style='color:red'>Please enter some text.</b>"
54
 
 
55
  paragraphs = re.split(r'\n{2,}', cleaned_text)
56
  if len(paragraphs) == 1 and len(cleaned_text.split()) > 300:
 
57
  words = cleaned_text.split()
58
  paragraphs = [' '.join(words[i:i + 300]) for i in range(0, len(words), 300)]
59
 
 
62
 
63
  for i, para in enumerate(paragraphs):
64
  inputs = tokenizer(para, return_tensors="pt", truncation=True, padding=True).to(device)
65
+ softmax_outputs = []
66
 
67
+ for m in [model_1, model_2, model_3]:
68
+ if m is None:
69
+ continue
70
+ with torch.no_grad():
71
+ logits = m(**inputs).logits
72
+ softmax_outputs.append(torch.softmax(logits, dim=1))
73
 
74
+ if not softmax_outputs:
75
+ return "<b style='color:red'>Error: No models loaded successfully.</b>"
 
76
 
77
+ avg_probs = sum(softmax_outputs) / len(softmax_outputs)
78
+ probabilities = avg_probs[0]
79
+ all_probabilities.append(probabilities.cpu())
80
 
81
  human_prob = probabilities[24].item()
82
+ ai_probs = probabilities.clone()
83
+ ai_probs[24] = 0
84
+ ai_prob = ai_probs.sum().item()
85
 
86
+ total = human_prob + ai_prob
87
  human_pct = (human_prob / total) * 100
88
+ ai_pct = (ai_prob / total) * 100
89
+ ai_model = label_mapping[torch.argmax(ai_probs).item()]
 
 
 
90
 
91
+ preview = para[:180].strip() + ("..." if len(para) > 180 else "")
92
  paragraph_scores.append({
93
  "id": i + 1,
94
  "human": human_pct,
95
  "ai": ai_pct,
96
  "model": ai_model,
97
+ "preview": preview
98
  })
99
 
 
100
  avg_human = sum(p["human"] for p in paragraph_scores) / len(paragraph_scores)
101
  avg_ai = sum(p["ai"] for p in paragraph_scores) / len(paragraph_scores)
102
 
103
  if avg_human > avg_ai:
104
+ overall = f"<b>Overall Result:</b> <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
105
  else:
106
+ top_model = max(paragraph_scores, key=lambda p: p["ai"])["model"]
107
+ overall = f"<b>Overall Result:</b> <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"
108
+
109
+ # --- Top 5 chart ---
110
+ mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
111
+ top_5_probs, top_5_indices = torch.topk(mean_probs, 5)
112
+ labels = [label_mapping[i.item()] for i in top_5_indices]
113
+ values = top_5_probs.cpu().numpy()
114
+
115
+ fig, ax = plt.subplots(figsize=(8, 4))
116
+ ax.barh(labels, values, color='#4CAF50')
117
+ ax.set_xlabel("Probability")
118
+ ax.set_title("Top 5 Model Predictions")
119
+ ax.invert_yaxis()
120
+ chart_html = plot_to_base64(fig)
121
 
122
+ # --- Paragraph breakdown ---
123
+ html = f"<div style='font-family:Arial, sans-serif; line-height:1.6'>{overall}<br><br>"
124
+ html += "<h3>Paragraph Analysis:</h3>"
125
 
126
  for p in paragraph_scores:
127
  color = "#28a745" if p["human"] > p["ai"] else "#FF5733"
128
+ html += f"""
129
  <div style='margin-bottom:10px; border-left:5px solid {color}; padding-left:10px; background:#f9f9f9; border-radius:6px;'>
130
  <b>Paragraph {p["id"]}</b>: {p["human"]:.2f}% Human | {p["ai"]:.2f}% AI → <i>{p["model"]}</i><br>
131
  <small>{p["preview"]}</small>
132
  </div>
133
  """
134
 
135
+ html += "<br><h3>Top 5 Models:</h3>" + chart_html + "</div>"
136
+ return html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # ===============================
139
+ # Gradio UI
140
+ # ===============================
141
  css = """
142
  .highlight-ai { color: #FF5733; font-weight: bold; }
143
  .highlight-human { color: #28a745; font-weight: bold; }
144
  """
145
 
146
  with gr.Blocks(css=css, theme="soft") as demo:
147
+ gr.Markdown("# 🧠 AI vs Human Text Detector")
148
+ txt = gr.Textbox(label="Paste your article", lines=12, placeholder="Enter your full text here...")
149
+ btn = gr.Button("Analyze", variant="primary")
150
+ out = gr.HTML(label="Results", elem_id="result-box")
151
+ btn.click(classify_text, inputs=txt, outputs=out)
 
 
152
 
153
  demo.launch()