# AIDetector / app.py — sentence-level AI-text detector (Hugging Face Space).
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
import math
import pandas as pd
import gradio as gr
# -----------------------------
# MODEL
# -----------------------------
# Binary AI-text classifier, loaded once at import time. The module-level
# names `tokenizer`, `device`, `dtype` and `model` are read by
# classify_sentences() below.
MODEL_NAME = "openai-community/roberta-base-openai-detector"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Prefer GPU; use bfloat16 only when the CUDA device supports it, else fp32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (device.type=="cuda" and torch.cuda.is_bf16_supported()) else torch.float32
# NOTE(review): the `dtype=` kwarg of from_pretrained requires a recent
# transformers release (older versions spell it `torch_dtype=`) — confirm the
# pinned transformers version supports it.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, dtype=dtype).to(device).eval()
# -----------------------------
# SENTENCE SPLITTER (no lookbehinds)
# Protect → split → restore
# -----------------------------
# Abbreviations whose trailing dot must NOT be treated as a sentence
# terminator (matched case-insensitively by ABBR_REGEX).
ABBR = [
    "e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc", "fig", "al",
    "jr", "sr", "st", "no", "vol", "pp", "mt", "inc", "ltd", "co", "u.s", "u.k",
    "a.m", "p.m"
]
# Matches any listed abbreviation immediately followed by a literal dot,
# e.g. "Dr." or "e.g."; group 1 captures the abbreviation itself.
ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", flags=re.IGNORECASE)
def _protect(text: str) -> str:
    """Mask dots that are not sentence terminators with sentinel tokens.

    Collapses newlines into single spaces, then shields ellipses, decimal
    points and known-abbreviation dots so the splitter will not break on
    them. Returns "" for blank input.
    """
    cleaned = text.strip()
    if not cleaned:
        return ""
    # Flatten line breaks into a continuous, Turnitin-like flow of text.
    cleaned = re.sub(r"\s*\n+\s*", " ", cleaned)
    # A literal "..." is an ellipsis, not three terminators.
    cleaned = cleaned.replace("...", "⟨ELLIPSIS⟩")
    # A dot squeezed between two digits is a decimal point (e.g. 3.14).
    cleaned = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", cleaned)
    # Shield the trailing dot of abbreviations such as "Dr." or "e.g.".
    cleaned = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", cleaned)
    return cleaned
def _restore(text: str) -> str:
return (text
.replace("⟨ABBRDOT⟩", ".")
.replace("⟨DECIMAL⟩", ".")
.replace("⟨ELLIPSIS⟩", "..."))
def sentence_split(text: str):
    """Split *text* into sentences without regex lookbehinds.

    Pipeline: protect non-terminal dots, split on ./?/! followed by
    whitespace and a plausible sentence opener (quote, capital letter or
    opening paren) or end of string, then restore the protected dots.
    Returns a list of non-empty sentence strings.
    """
    protected = _protect(text)
    if not protected:
        return []
    # The capture group keeps each terminator, so re.split yields
    # [text, delim, text, delim, ..., tail].
    pieces = re.split(r"([.?!])\s+(?=(?:[\"“”‘’']?\s*[A-Z(])|$)", protected)
    bodies = pieces[0::2]   # text chunks (even positions)
    enders = pieces[1::2]   # captured [.?!] delimiters (odd positions)
    raw = [body + mark for body, mark in zip(bodies, enders)]
    # An unterminated tail (more bodies than delimiters) is its own sentence.
    if len(bodies) > len(enders) and bodies[-1].strip():
        raw.append(bodies[-1])
    # Restore protected dots and drop empty fragments.
    return [_restore(part).strip() for part in raw if part.strip()]
# -----------------------------
# UTILITIES
# -----------------------------
def batched(iterable, n=64):
    """Yield (chunk, start_offset) pairs of length-n slices of *iterable*."""
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start:start + n], start
        start += n
def contig_spans(labels):
    """Return (number_of_AI_runs, longest_AI_run) over a label sequence.

    A "run" is a maximal stretch of consecutive "AI" labels.
    """
    spans = 0
    longest = 0
    current = 0
    for label in labels:
        if label == "AI":
            current += 1
            if current > longest:
                longest = current
        elif current:
            # A non-AI label closes the open run.
            spans += 1
            current = 0
    if current:
        # Sequence ended inside an AI run.
        spans += 1
    return spans, longest
def verdict_from_stats(flag_pct, longest_span, avg_ai_prob):
    """Map aggregate detector statistics to a human-readable verdict string."""
    strongly_ai = flag_pct >= 85 and longest_span >= 6 and avg_ai_prob >= 0.80
    if strongly_ai:
        return "⚠️ Highly likely AI-generated (long consecutive spans and high prevalence)."
    if flag_pct >= 60 and longest_span >= 4:
        return "⚠️ Strong AI signals (multiple/long spans)."
    if flag_pct >= 30 or longest_span >= 3:
        return "△ Some AI indicators (partial/short spans)."
    return "✓ No clear AI indication (by this detector)."
# -----------------------------
# CORE CLASSIFIER
# -----------------------------
def classify_sentences(text, ai_threshold=0.70, batch_size=64, max_len=512):
    """Score each sentence of *text* with the detector model.

    Args:
        text: Raw input text; split into sentences via sentence_split().
        ai_threshold: Minimum AI probability for a sentence to be labeled "AI".
        batch_size: Number of sentences tokenized/scored per forward pass.
        max_len: Per-sentence token truncation length.

    Returns:
        (sentences, rows, avg_ai_prob, flagged_pct, (span_count, longest_span))
        where rows is a list of dicts ready for a DataFrame, avg_ai_prob the
        mean per-sentence AI probability, and flagged_pct the percentage of
        sentences at/above the threshold.
    """
    sents = sentence_split(text)
    if not sents:
        return [], [], 0.0, 0.0, (0, 0)

    # BUG FIX: the original hard-coded probs[:, 1] as the AI probability, but
    # for openai-community/roberta-base-openai-detector the config maps
    # label 0 -> "Fake" (AI-generated) and label 1 -> "Real" (human), so the
    # scores were inverted. Resolve the AI column from the model config
    # instead of assuming it; fall back to the previous index 1 when the
    # labels are non-standard (e.g. LABEL_0/LABEL_1).
    id2label = getattr(model.config, "id2label", None) or {}
    ai_index = next(
        (int(i) for i, lab in id2label.items()
         if str(lab).lower() in ("fake", "ai", "machine")),
        1,
    )

    all_probs = []
    for chunk, _ in batched(sents, n=batch_size):
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len
        ).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
        all_probs.extend(probs[:, ai_index].detach().cpu().tolist())

    labels = ["AI" if p >= ai_threshold else "Human" for p in all_probs]
    avg_ai_prob = float(sum(all_probs) / len(all_probs))
    flagged_pct = 100.0 * sum(1 for l in labels if l == "AI") / len(labels)
    spans = contig_spans(labels)

    # One row per sentence, 1-indexed for display.
    rows = []
    for i, (s, p, lab) in enumerate(zip(sents, all_probs, labels), start=1):
        rows.append({
            "Sentence #": i,
            "Sentence": s,
            "AI Probability": round(p, 4),
            "Label": lab
        })
    return sents, rows, avg_ai_prob, flagged_pct, spans
# -----------------------------
# HTML HIGHLIGHT
# -----------------------------
def color_for_prob(p):
    """Traffic-light hex color for an AI probability: green / amber / red."""
    if p >= 0.70:
        return "#b80d0d"  # red: likely AI
    if p >= 0.30:
        return "#b8860b"  # amber: uncertain
    return "#11823b"      # green: likely human
def build_highlight_html(rows):
    """Render per-sentence report rows as color-coded HTML blocks."""
    spans = []
    for row in rows:
        prob = row["AI Probability"]
        colour = color_for_prob(prob)
        shown = f"{prob*100:.1f}%"
        # Collapse any internal whitespace runs for clean inline display.
        body = re.sub(r"\s+", " ", row["Sentence"]).strip()
        spans.append(
            f"<span style='background:rgba(0,0,0,0.02); "
            f"padding:4px 6px; border-radius:6px; display:block; margin:6px 0;'>"
            f"<strong style='color:{colour}'>[{shown} {row['Label']}]</strong> {body}</span>"
        )
    return "\n".join(spans)
# -----------------------------
# PUBLIC API FOR GRADIO
# -----------------------------
def generate_report(text, threshold):
    """Gradio handler: run detection and return (summary, html, dataframe, pct)."""
    if not text or not text.strip():
        return "⚠️ Please enter some text.", None, None, None

    result = classify_sentences(text, ai_threshold=threshold)
    sents, rows, avg_ai_prob, flagged_pct, span_stats = result
    span_count, longest_span = span_stats
    verdict = verdict_from_stats(flagged_pct, longest_span, avg_ai_prob)

    summary_lines = [
        "⚖️ Turnitin-style Summary",
        f"- Overall AI probability (avg per sentence): {avg_ai_prob*100:.1f}%",
        f"- Sentences flagged as AI ≥ {int(threshold*100)}%: {flagged_pct:.1f}%",
        f"- Consecutive AI spans: {span_count} (longest: {longest_span})",
        f"- Verdict: {verdict}",
        "",
        "ⓘ This is an approximation using an open detector; actual Turnitin results may differ.",
    ]
    overall = "\n".join(summary_lines)

    html = build_highlight_html(rows)
    df = pd.DataFrame(rows, columns=["Sentence #", "Sentence", "AI Probability", "Label"])
    return overall, html, df, f"{flagged_pct:.1f}%"
# -----------------------------
# GRADIO UI
# -----------------------------
# Declarative Blocks layout; component creation order defines on-page order.
with gr.Blocks() as demo:
    gr.Markdown("## 🧭 Writenix AI Detector — Turnitin-style (Sentence-Level)")
    with gr.Row():
        text_input = gr.Textbox(
            label="Paste your content",
            lines=16,
            placeholder="Drop your essay/article here…"
        )
    with gr.Row():
        # Slider value feeds classify_sentences(ai_threshold=...).
        threshold = gr.Slider(
            0.50, 0.95, value=0.70, step=0.01,
            label="AI Flag Threshold (probability ≥ threshold ⇒ AI)"
        )
    detect_btn = gr.Button("🔎 Analyze")
    with gr.Row():
        ai_summary = gr.Textbox(label="Report Summary", lines=8)
        flagged_pct = gr.Label(label="% Sentences Flagged (AI)")
    highlighted = gr.HTML(label="Per-Sentence Highlights")
    table = gr.Dataframe(headers=["Sentence #", "Sentence", "AI Probability", "Label"], wrap=True)
    # Wire the button: outputs map 1:1 to generate_report's return tuple.
    detect_btn.click(
        fn=generate_report,
        inputs=[text_input, threshold],
        outputs=[ai_summary, highlighted, table, flagged_pct]
    )
if __name__ == "__main__":
    demo.launch()