# NOTE: Hugging Face Spaces status-badge residue from the copied page
# ("Spaces: Running") — not part of the program.
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import re | |
| import math | |
| import pandas as pd | |
| import gradio as gr | |
# -----------------------------
# MODEL
# -----------------------------
# Off-the-shelf GPT-2-output detector used for per-sentence scoring.
MODEL_NAME = "openai-community/roberta-base-openai-detector"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Prefer GPU; use bfloat16 only when the CUDA device supports it, else fp32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (device.type=="cuda" and torch.cuda.is_bf16_supported()) else torch.float32
# NOTE(review): `dtype=` is the kwarg on transformers >= 4.56; older releases
# expect `torch_dtype=` — confirm the pinned transformers version supports it.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, dtype=dtype).to(device).eval()
| # ----------------------------- | |
| # SENTENCE SPLITTER (no lookbehinds) | |
| # Protect → split → restore | |
| # ----------------------------- | |
| ABBR = [ | |
| "e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc", "fig", "al", | |
| "jr", "sr", "st", "no", "vol", "pp", "mt", "inc", "ltd", "co", "u.s", "u.k", | |
| "a.m", "p.m" | |
| ] | |
| ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", flags=re.IGNORECASE) | |
| def _protect(text: str) -> str: | |
| t = text.strip() | |
| if not t: | |
| return "" | |
| # Normalize newlines to spaces (Turnitin-like continuous flow) | |
| t = re.sub(r"\s*\n+\s*", " ", t) | |
| # Protect ellipses | |
| t = t.replace("...", "⟨ELLIPSIS⟩") | |
| # Protect decimals like 3.14 | |
| t = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", t) | |
| # Protect known abbreviations' final dot | |
| t = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", t) | |
| return t | |
| def _restore(text: str) -> str: | |
| return (text | |
| .replace("⟨ABBRDOT⟩", ".") | |
| .replace("⟨DECIMAL⟩", ".") | |
| .replace("⟨ELLIPSIS⟩", "...")) | |
def sentence_split(text: str):
    """Split *text* into sentences via protect → split → restore.

    Returns a list of non-empty, stripped sentences; [] for empty input.
    """
    protected = _protect(text)
    if not protected:
        return []
    # Split on ., ?, ! followed by whitespace when the next token looks like
    # a sentence opener (optional quote, then capital letter or "(") or at
    # end of string. The capturing group keeps each terminator, so the split
    # result alternates [text, terminator, text, terminator, ..., text].
    pieces = re.split(r"([.?!])\s+(?=(?:[\"“”‘’']?\s*[A-Z(])|$)", protected)
    sentences = []
    for j in range(0, len(pieces), 2):
        terminator = pieces[j + 1] if j + 1 < len(pieces) else ""
        candidate = (pieces[j] + terminator).strip()
        if candidate:
            sentences.append(_restore(candidate).strip())
    return sentences
| # ----------------------------- | |
| # UTILITIES | |
| # ----------------------------- | |
def batched(iterable, n=64):
    """Yield (slice, start_offset) pairs of at most *n* items from a sequence."""
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start:start + n], start
        start += n
def contig_spans(labels):
    """Return (number of maximal "AI" runs, length of the longest run)."""
    span_count = 0
    longest = 0
    current = 0
    for label in labels:
        if label == "AI":
            current += 1
            continue
        # Run just ended: record it once.
        if current:
            span_count += 1
            longest = max(longest, current)
            current = 0
    # Account for a run that reaches the end of the list.
    if current:
        span_count += 1
        longest = max(longest, current)
    return span_count, longest
def verdict_from_stats(flag_pct, longest_span, avg_ai_prob):
    """Map aggregate stats to a verdict string; first matching rule wins."""
    rules = (
        (flag_pct >= 85 and longest_span >= 6 and avg_ai_prob >= 0.80,
         "⚠️ Highly likely AI-generated (long consecutive spans and high prevalence)."),
        (flag_pct >= 60 and longest_span >= 4,
         "⚠️ Strong AI signals (multiple/long spans)."),
        (flag_pct >= 30 or longest_span >= 3,
         "△ Some AI indicators (partial/short spans)."),
    )
    for matched, message in rules:
        if matched:
            return message
    return "✓ No clear AI indication (by this detector)."
| # ----------------------------- | |
| # CORE CLASSIFIER | |
| # ----------------------------- | |
def classify_sentences(text, ai_threshold=0.70, batch_size=64, max_len=512):
    """Score each sentence of *text* with the detector model.

    Args:
        text: Raw input document.
        ai_threshold: Sentences with AI probability >= this are labeled "AI".
        batch_size: Number of sentences per model forward pass.
        max_len: Per-sentence token truncation length.

    Returns:
        (sentences, rows, avg_ai_prob, flagged_pct, (span_count, longest_span));
        rows are dicts ready for the report table. Empty input yields
        ([], [], 0.0, 0.0, (0, 0)).
    """
    sents = sentence_split(text)
    if not sents:
        return [], [], 0.0, 0.0, (0, 0)

    # FIX: resolve which logit column is the machine-generated class instead
    # of hard-coding column 1. For this checkpoint config.id2label is
    # {0: "Fake", 1: "Real"}, so column 0 — not 1 — is the AI score; the old
    # code therefore reported the *human* probability as "AI probability".
    # Fall back to the previous behavior (column 1) if labels are opaque
    # (e.g. "LABEL_0"/"LABEL_1").
    ai_index = 1
    for idx, name in getattr(model.config, "id2label", {}).items():
        if str(name).strip().lower() in {"fake", "ai", "machine", "generated"}:
            ai_index = int(idx)
            break

    all_probs = []
    for chunk, _ in batched(sents, n=batch_size):
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len
        ).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
        # .float() upcasts bf16 before the Python-number conversion.
        all_probs.extend(probs[:, ai_index].float().cpu().tolist())

    labels = ["AI" if p >= ai_threshold else "Human" for p in all_probs]
    avg_ai_prob = float(sum(all_probs) / len(all_probs))
    flagged_pct = 100.0 * sum(1 for lab in labels if lab == "AI") / len(labels)
    spans = contig_spans(labels)

    rows = [
        {
            "Sentence #": i,
            "Sentence": s,
            "AI Probability": round(p, 4),
            "Label": lab,
        }
        for i, (s, p, lab) in enumerate(zip(sents, all_probs, labels), start=1)
    ]
    return sents, rows, avg_ai_prob, flagged_pct, spans
| # ----------------------------- | |
| # HTML HIGHLIGHT | |
| # ----------------------------- | |
def color_for_prob(p):
    """Map an AI probability to a highlight color: green < 0.30 <= amber < 0.70 <= red."""
    for cutoff, color in ((0.30, "#11823b"), (0.70, "#b8860b")):
        if p < cutoff:
            return color
    return "#b80d0d"

def build_highlight_html(rows):
    """Render each scored sentence as a color-coded HTML block."""
    rendered = []
    for row in rows:
        prob = row["AI Probability"]
        color = color_for_prob(prob)
        badge = f"{prob*100:.1f}% {row['Label']}"
        # Collapse internal whitespace so each sentence renders on one line.
        body = re.sub(r"\s+", " ", row["Sentence"]).strip()
        rendered.append(
            "<span style='background:rgba(0,0,0,0.02); "
            "padding:4px 6px; border-radius:6px; display:block; margin:6px 0;'>"
            f"<strong style='color:{color}'>[{badge}]</strong> {body}</span>"
        )
    return "\n".join(rendered)
| # ----------------------------- | |
| # PUBLIC API FOR GRADIO | |
| # ----------------------------- | |
def generate_report(text, threshold):
    """Run the detector over *text* and package everything for the Gradio UI.

    Returns (summary_text, highlight_html, dataframe, flagged_pct_string);
    on empty input the last three slots are None.
    """
    if not text or not text.strip():
        return "⚠️ Please enter some text.", None, None, None

    sents, rows, avg_ai_prob, flagged_pct, (span_count, longest_span) = classify_sentences(
        text, ai_threshold=threshold
    )
    verdict = verdict_from_stats(flagged_pct, longest_span, avg_ai_prob)

    summary_lines = [
        "⚖️ Turnitin-style Summary",
        f"- Overall AI probability (avg per sentence): {avg_ai_prob*100:.1f}%",
        f"- Sentences flagged as AI ≥ {int(threshold*100)}%: {flagged_pct:.1f}%",
        f"- Consecutive AI spans: {span_count} (longest: {longest_span})",
        f"- Verdict: {verdict}",
        "",
        "ⓘ This is an approximation using an open detector; actual Turnitin results may differ.",
    ]
    overall = "\n".join(summary_lines)

    html = build_highlight_html(rows)
    df = pd.DataFrame(rows, columns=["Sentence #", "Sentence", "AI Probability", "Label"])
    return overall, html, df, f"{flagged_pct:.1f}%"
| # ----------------------------- | |
| # GRADIO UI | |
| # ----------------------------- | |
# -----------------------------
# GRADIO UI
# -----------------------------
# NOTE(review): original indentation was lost in the paste; the nesting below
# is reconstructed (components grouped into their Rows) — confirm layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🧭 Writenix AI Detector — Turnitin-style (Sentence-Level)")
    with gr.Row():
        text_input = gr.Textbox(
            label="Paste your content",
            lines=16,
            placeholder="Drop your essay/article here…"
        )
    with gr.Row():
        # Probability cutoff passed to generate_report / classify_sentences.
        threshold = gr.Slider(
            0.50, 0.95, value=0.70, step=0.01,
            label="AI Flag Threshold (probability ≥ threshold ⇒ AI)"
        )
    detect_btn = gr.Button("🔎 Analyze")
    with gr.Row():
        ai_summary = gr.Textbox(label="Report Summary", lines=8)
        flagged_pct = gr.Label(label="% Sentences Flagged (AI)")
    highlighted = gr.HTML(label="Per-Sentence Highlights")
    table = gr.Dataframe(headers=["Sentence #", "Sentence", "AI Probability", "Label"], wrap=True)
    # Outputs are positional and must match generate_report's return tuple.
    detect_btn.click(
        fn=generate_report,
        inputs=[text_input, threshold],
        outputs=[ai_summary, highlighted, table, flagged_pct]
    )

if __name__ == "__main__":
    demo.launch()