VictorM-Coder commited on
Commit
ea2c6a2
·
verified ·
1 Parent(s): 6d5b664

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -0
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import re
5
+ import pandas as pd
6
+ import gradio as gr
7
+
8
+ # -----------------------------
9
+ # MODEL INITIALIZATION
10
+ # -----------------------------
11
+ MODEL_NAME = "fakespot-ai/roberta-base-ai-text-detection-v1"
12
+
13
+ tokenizer = None
14
+ model = None
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ def get_model():
18
+ global tokenizer, model
19
+ if model is None:
20
+ # Loading messages to help debug in HF logs
21
+ print(f"Loading model: {MODEL_NAME} on {device}")
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
+
24
+ # Use float32 on CPU to prevent build-time precision hangs
25
+ dtype = torch.float32
26
+ if device.type == "cuda" and torch.cuda.is_bf16_supported():
27
+ dtype = torch.bfloat16
28
+
29
+ model = AutoModelForSequenceClassification.from_pretrained(
30
+ MODEL_NAME, torch_dtype=dtype
31
+ ).to(device).eval()
32
+ return tokenizer, model
33
+
34
+ THRESHOLD = 0.20
35
+
36
+ # -----------------------------
37
+ # PROTECT STRUCTURE
38
+ # -----------------------------
39
+ ABBR = ["e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc", "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k"]
40
+ ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
41
+
42
+ def _protect(text):
43
+ text = text.replace("...", "⟨ELLIPSIS⟩")
44
+ text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
45
+ text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
46
+ return text
47
+
48
+ def _restore(text):
49
+ return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
50
+
51
+ def split_preserving_structure(text):
52
+ blocks = re.split(r"(\n+)", text)
53
+ final_blocks = []
54
+ for block in blocks:
55
+ if block.startswith("\n"):
56
+ final_blocks.append(block)
57
+ else:
58
+ protected = _protect(block)
59
+ parts = re.split(r"([.?!])(\s+)", protected)
60
+ for i in range(0, len(parts), 3):
61
+ sentence = parts[i]
62
+ punct = parts[i+1] if i+1 < len(parts) else ""
63
+ space = parts[i+2] if i+2 < len(parts) else ""
64
+ if sentence.strip():
65
+ final_blocks.append(_restore(sentence + punct))
66
+ if space:
67
+ final_blocks.append(space)
68
+ return final_blocks
69
+
70
+ # -----------------------------
71
+ # ANALYSIS
72
+ # -----------------------------
73
+ @torch.inference_mode()
74
+ def analyze(text):
75
+ if not text.strip():
76
+ return "—", "—", "<em>Please enter text...</em>", None
77
+ try:
78
+ tok, mod = get_model()
79
+ except Exception as e:
80
+ return "ERROR", "0%", f"Failed to load model: {str(e)}", None
81
+
82
+ blocks = split_preserving_structure(text)
83
+ pure_sents_indices = [i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")]
84
+ pure_sents = [blocks[i] for i in pure_sents_indices]
85
+
86
+ if not pure_sents:
87
+ return "—", "—", "<em>No sentences detected.</em>", None
88
+
89
+ windows = []
90
+ for i in range(len(pure_sents)):
91
+ start = max(0, i - 1)
92
+ end = min(len(pure_sents), i + 2)
93
+ windows.append(" ".join(pure_sents[start:end]))
94
+
95
+ inputs = tok(windows, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
96
+ logits = mod(**inputs).logits
97
+ probs = F.softmax(logits.float(), dim=-1)[:, 1].cpu().numpy().tolist()
98
+
99
+ lengths = [len(s.split()) for s in pure_sents]
100
+ total_words = sum(lengths)
101
+ weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0
102
+
103
+ # -----------------------------
104
+ # HTML RECONSTRUCTION
105
+ # -----------------------------
106
+ highlighted_html = "<div style='font-family: sans-serif; line-height: 1.8;'>"
107
+ prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
108
+
109
+ for i, block in enumerate(blocks):
110
+ if block.startswith("\n") or block.isspace():
111
+ highlighted_html += block.replace("\n", "<br>")
112
+ continue
113
+
114
+ if i in prob_map:
115
+ score = prob_map[i]
116
+ if score < 0.35: color, bg = "#11823b", "rgba(17, 130, 59, 0.15)"
117
+ elif score < 0.70: color, bg = "#b8860b", "rgba(184, 134, 11, 0.15)"
118
+ else: color, bg = "#b80d0d", "rgba(184, 13, 13, 0.15)"
119
+
120
+ highlighted_html += (
121
+ f"<span style='background:{bg}; padding:2px 4px; border-radius:4px; border-bottom: 2px solid {color};' "
122
+ f"title='AI Probability: {score:.1%}'>"
123
+ f"<b style='color:{color}; font-size: 0.8em;'>[{score:.0%}]</b> {block}</span>"
124
+ )
125
+ else:
126
+ highlighted_html += block
127
+
128
+ highlighted_html += "</div>"
129
+
130
+ # Updated Label Logic per your request
131
+ label = "AI Content Detected" if weighted_avg >= THRESHOLD else "0 or * AI Content Detected"
132
+
133
+ df = pd.DataFrame({"Sentence": pure_sents, "AI Confidence": [f"{p:.1%}" for p in probs]})
134
+ return label, f"{weighted_avg:.1%}", highlighted_html, df
135
+
136
+ # -----------------------------
137
+ # GRADIO INTERFACE
138
+ # -----------------------------
139
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
140
+ gr.Markdown("## 🕵️ AI Detector Pro")
141
+ gr.Markdown("Sentence-level analysis with weighted context windows.")
142
+
143
+ with gr.Row():
144
+ with gr.Column(scale=3):
145
+ text_input = gr.Textbox(label="Paste Text", lines=12)
146
+ run_btn = gr.Button("Analyze", variant="primary")
147
+ with gr.Column(scale=1):
148
+ verdict_out = gr.Label(label="Verdict")
149
+ score_out = gr.Label(label="Weighted AI Score")
150
+
151
+ with gr.Tabs():
152
+ with gr.TabItem("Visual Heatmap"):
153
+ html_out = gr.HTML()
154
+ with gr.TabItem("Raw Data Breakdown"):
155
+ table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)
156
+
157
+ run_btn.click(analyze, inputs=text_input, outputs=[verdict_out, score_out, html_out, table_out])
158
+
159
+ if __name__ == "__main__":
160
+ demo.launch()