KavinduHansaka commited on
Commit
f1f7273
·
verified ·
1 Parent(s): 31c512a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -52
app.py CHANGED
@@ -1,74 +1,114 @@
1
  import gradio as gr
2
  import pandas as pd
3
- from detoxify import Detoxify
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
 
 
6
 
7
- # Load models once
8
- tox_model = Detoxify('multilingual')
9
- ai_tokenizer = AutoTokenizer.from_pretrained("openai-community/roberta-base-openai-detector")
10
- ai_model = AutoModelForSequenceClassification.from_pretrained("openai-community/roberta-base-openai-detector")
11
-
12
- # Thresholds
13
- TOXICITY_THRESHOLD = 0.7
14
  AI_THRESHOLD = 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- def detect_ai(text):
17
- with torch.no_grad():
18
- inputs = ai_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
19
- logits = ai_model(**inputs).logits
20
- probs = torch.softmax(logits, dim=1).squeeze().tolist()
21
- return round(probs[1], 4) # Probability of AI-generated
22
-
23
- def classify_comments(comment_list):
24
- results = tox_model.predict(comment_list)
25
- df = pd.DataFrame(results, index=comment_list).round(4)
26
- df.columns = [col.replace("_", " ").title().replace(" ", "_") for col in df.columns]
27
- df.columns = [col.replace("_", " ") for col in df.columns]
28
- df["⚠️ Warning"] = df.apply(
29
- lambda row: "⚠️ High Risk" if any(score > TOXICITY_THRESHOLD for score in row) else "✅ Safe",
30
- axis=1
31
- )
32
- df["🧪 AI Probability"] = [detect_ai(c) for c in df.index]
33
- df["🧪 AI Detection"] = df["🧪 AI Probability"].apply(
34
- lambda x: "🤖 Likely AI" if x > AI_THRESHOLD else "🧍 Human"
35
- )
36
  return df
37
 
38
- def run_classification(text_input, csv_file):
39
- comment_list = []
 
 
 
 
40
 
41
  if text_input.strip():
42
- comment_list += [c.strip() for c in text_input.strip().split('\n') if c.strip()]
43
 
44
  if csv_file:
45
  df = pd.read_csv(csv_file.name)
46
- if 'comment' not in df.columns:
47
- return "CSV must contain a 'comment' column.", None
48
- comment_list += df['comment'].astype(str).tolist()
49
 
50
- if not comment_list:
51
- return "Please provide comments via text or CSV.", None
52
 
53
- df = classify_comments(comment_list)
54
- csv_data = df.copy()
55
- csv_data.insert(0, "Comment", df.index)
56
- return df, ("toxicity_predictions.csv", csv_data.to_csv(index=False).encode())
57
 
58
- # Build the Gradio UI
59
- with gr.Blocks(title="🌍 Toxic Comment & AI Detector") as app:
60
- gr.Markdown("## 🌍 Toxic Comment & AI Detector")
61
- gr.Markdown("Detects multilingual toxicity and whether a comment is AI-generated. Paste comments or upload a CSV.")
 
 
 
 
 
 
62
 
63
  with gr.Row():
64
- text_input = gr.Textbox(lines=8, label="💬 Paste Comments (one per line)")
65
- csv_input = gr.File(label="📥 Upload CSV (must have 'comment' column)")
 
 
 
 
 
 
 
66
 
67
- submit_button = gr.Button("🔍 Analyze Comments")
68
- output_table = gr.Dataframe(label="📊 Prediction Results")
69
- download_button = gr.File(label="📤 Download CSV")
70
 
71
- submit_button.click(fn=run_classification, inputs=[text_input, csv_input], outputs=[output_table, download_button])
 
 
 
 
72
 
73
  if __name__ == "__main__":
74
- app.launch()
 
1
  import gradio as gr
2
  import pandas as pd
 
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ from typing import List, Tuple
6
 
7
+ # =========================
8
+ # Configuration
9
+ # =========================
10
+ MODEL_NAME = "openai-community/roberta-base-openai-detector"
 
 
 
11
  AI_THRESHOLD = 0.5
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # =========================
15
+ # Model Loading (once)
16
+ # =========================
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
19
+ model.to(DEVICE)
20
+ model.eval()
21
+
22
+ # =========================
23
+ # Core Logic
24
+ # =========================
25
+ @torch.no_grad()
26
+ def detect_ai_probability(texts: List[str]) -> List[float]:
27
+ """
28
+ Returns probability that each text is AI-generated.
29
+ """
30
+ inputs = tokenizer(
31
+ texts,
32
+ return_tensors="pt",
33
+ padding=True,
34
+ truncation=True,
35
+ max_length=512
36
+ ).to(DEVICE)
37
+
38
+ logits = model(**inputs).logits
39
+ probs = torch.softmax(logits, dim=1)[:, 1] # AI-generated class
40
+ return probs.cpu().tolist()
41
+
42
+
43
+ def classify_texts(texts: List[str]) -> pd.DataFrame:
44
+ """
45
+ Classify texts as AI or Human.
46
+ """
47
+ probabilities = detect_ai_probability(texts)
48
+
49
+ df = pd.DataFrame({
50
+ "Comment": texts,
51
+ "AI Probability": [round(p, 4) for p in probabilities],
52
+ "Prediction": [
53
+ "🤖 Likely AI" if p >= AI_THRESHOLD else "🧍 Human"
54
+ for p in probabilities
55
+ ]
56
+ })
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  return df
59
 
60
+
61
+ def run_detector(text_input: str, csv_file) -> Tuple[pd.DataFrame, Tuple[str, bytes]]:
62
+ """
63
+ Handles UI input and output.
64
+ """
65
+ texts: List[str] = []
66
 
67
  if text_input.strip():
68
+ texts.extend([t.strip() for t in text_input.split("\n") if t.strip()])
69
 
70
  if csv_file:
71
  df = pd.read_csv(csv_file.name)
72
+ if "comment" not in df.columns:
73
+ return pd.DataFrame({"Error": ["CSV must contain a 'comment' column"]}), None
74
+ texts.extend(df["comment"].astype(str).tolist())
75
 
76
+ if not texts:
77
+ return pd.DataFrame({"Error": ["No input provided"]}), None
78
 
79
+ result_df = classify_texts(texts)
 
 
 
80
 
81
+ csv_bytes = result_df.to_csv(index=False).encode("utf-8")
82
+ return result_df, ("ai_detection_results.csv", csv_bytes)
83
+
84
+
85
+ # =========================
86
+ # Gradio UI
87
+ # =========================
88
+ with gr.Blocks(title="🧪 AI Text Detector") as app:
89
+ gr.Markdown("## 🧪 AI Text Detector")
90
+ gr.Markdown("Detect whether text is **AI-generated or human-written**.")
91
 
92
  with gr.Row():
93
+ text_input = gr.Textbox(
94
+ lines=8,
95
+ label="✍️ Paste Text (one per line)",
96
+ placeholder="Enter multiple comments, one per line..."
97
+ )
98
+ csv_input = gr.File(
99
+ label="📄 Upload CSV",
100
+ file_types=[".csv"]
101
+ )
102
 
103
+ analyze_btn = gr.Button("🔍 Analyze")
104
+ output_table = gr.Dataframe(label="📊 Results")
105
+ download_file = gr.File(label="⬇️ Download CSV")
106
 
107
+ analyze_btn.click(
108
+ fn=run_detector,
109
+ inputs=[text_input, csv_input],
110
+ outputs=[output_table, download_file]
111
+ )
112
 
113
  if __name__ == "__main__":
114
+ app.launch()