mahmoudsaber0 commited on
Commit
a03c764
·
verified ·
1 Parent(s): 9cf13c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -79
app.py CHANGED
@@ -1,99 +1,78 @@
1
- import re
2
- import torch
3
  from fastapi import FastAPI, Request
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  from pydantic import BaseModel
6
- from typing import List
7
- import uvicorn
8
-
9
- # ========== CONFIG ==========
10
- MODEL_PATH = "roberta-base-openai-detector" # or your preferred detector
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
-
13
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
14
- model_1 = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
15
- model_2 = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
16
- model_3 = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
17
 
18
- label_mapping = {
19
- 0: "gpt2", 1: "gpt3", 2: "gpt4", 3: "chatgpt", 4: "dolly", 5: "human", 24: "human"
20
- }
21
 
22
- app = FastAPI(title="AI Text Classifier API", version="1.0.0")
 
23
 
 
 
 
 
 
24
 
25
- # ========== HELPERS ==========
26
  def clean_text(text: str) -> str:
27
- text = re.sub(r'\s+', ' ', text)
 
28
  return text.strip()
29
 
 
 
 
30
 
31
- # ========== INPUT MODEL ==========
32
- class TextInput(BaseModel):
33
- text: str
 
34
 
 
 
35
 
36
- # ========== MAIN LOGIC ==========
37
- @app.post("/analyze")
38
- async def analyze_text(data: TextInput):
39
- cleaned_text = clean_text(data.text)
40
- if not cleaned_text.strip():
41
- return {"success": False, "error": "Empty text provided"}
42
-
43
- paragraphs = [p.strip() for p in re.split(r'\n{2,}', cleaned_text) if p.strip()]
44
- if not paragraphs:
45
- paragraphs = [cleaned_text]
46
-
47
- chunk_scores = []
48
- all_probs = []
49
-
50
- for paragraph in paragraphs:
51
- inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True).to(device)
52
  with torch.no_grad():
53
- logits_1 = model_1(**inputs).logits
54
- logits_2 = model_2(**inputs).logits
55
- logits_3 = model_3(**inputs).logits
56
- softmax_1 = torch.softmax(logits_1, dim=1)
57
- softmax_2 = torch.softmax(logits_2, dim=1)
58
- softmax_3 = torch.softmax(logits_3, dim=1)
59
- averaged = (softmax_1 + softmax_2 + softmax_3) / 3
60
- probs = averaged[0]
61
- all_probs.append(probs.cpu())
62
-
63
- human_prob = probs[24].item() if 24 in label_mapping else probs[-1].item()
64
- ai_probs_clone = probs.clone()
65
- ai_probs_clone[24] = 0
66
- ai_total = ai_probs_clone.sum().item()
67
- total = human_prob + ai_total
68
- human_pct = (human_prob / total) * 100
69
- ai_pct = (ai_total / total) * 100
70
- ai_model = label_mapping[torch.argmax(ai_probs_clone).item()]
71
-
72
- chunk_scores.append({
73
- "human": round(human_pct, 2),
74
- "ai": round(ai_pct, 2),
75
- "model": ai_model,
76
- "text_preview": paragraph[:250].replace('\n', ' ') + ("..." if len(paragraph) > 250 else "")
77
  })
78
 
79
- # ---- OVERALL ----
80
- avg_human = sum(c["human"] for c in chunk_scores) / len(chunk_scores)
81
- avg_ai = sum(c["ai"] for c in chunk_scores) / len(chunk_scores)
82
-
83
- if avg_ai > avg_human:
84
- top_model = max(chunk_scores, key=lambda c: c["ai"])["model"]
85
- overall = {"result": f"{avg_ai:.2f}% AI-generated", "model": top_model}
86
- else:
87
- overall = {"result": f"{avg_human:.2f}% Human-written", "model": "human"}
88
 
89
  return {
90
- "success": True,
91
- "overall": overall,
92
- "paragraphs": chunk_scores,
93
- "total_paragraphs": len(chunk_scores)
 
 
94
  }
95
 
 
 
 
96
 
97
- # ========== RUN LOCALLY ==========
98
- if __name__ == "__main__":
99
- uvicorn.run("app:app", host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, Request
 
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+ import re
 
 
 
 
 
 
 
 
6
 
7
+ app = FastAPI(title="AI Text Detector API")
 
 
8
 
9
+ # Device setup
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # Load model (use small model for Hugging Face to prevent restarts)
13
+ MODEL_NAME = "roberta-base-openai-detector"
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
16
+ model.eval()
17
 
18
+ # --- Text Cleaning ---
19
  def clean_text(text: str) -> str:
20
+ text = re.sub(r'\s{2,}', ' ', text)
21
+ text = re.sub(r'\s+([,.;:?!])', r'\1', text)
22
  return text.strip()
23
 
24
+ # --- Paragraph Splitter ---
25
+ def split_paragraphs(text: str):
26
+ return [p.strip() for p in re.split(r'\n{2,}', text) if p.strip()]
27
 
28
+ # --- Classification ---
29
+ def analyze_text(text: str):
30
+ text = clean_text(text)
31
+ paragraphs = split_paragraphs(text)
32
 
33
+ paragraph_results = []
34
+ total_ai, total_human = 0, 0
35
 
36
+ for i, p in enumerate(paragraphs, 1):
37
+ inputs = tokenizer(p, return_tensors="pt", truncation=True, padding=True).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  with torch.no_grad():
39
+ logits = model(**inputs).logits
40
+ probs = torch.softmax(logits, dim=1)[0]
41
+ ai_score = float(probs[1].item() * 100)
42
+ human_score = float(probs[0].item() * 100)
43
+
44
+ total_ai += ai_score
45
+ total_human += human_score
46
+
47
+ paragraph_results.append({
48
+ "paragraph_number": i,
49
+ "ai_probability": round(ai_score, 2),
50
+ "human_probability": round(human_score, 2),
51
+ "text_snippet": p[:150] + ("..." if len(p) > 150 else "")
 
 
 
 
 
 
 
 
 
 
 
52
  })
53
 
54
+ avg_ai = total_ai / len(paragraphs)
55
+ avg_human = total_human / len(paragraphs)
56
+ overall_label = "AI-generated" if avg_ai > avg_human else "Human-written"
 
 
 
 
 
 
57
 
58
  return {
59
+ "overall_result": {
60
+ "ai_percentage": round(avg_ai, 2),
61
+ "human_percentage": round(avg_human, 2),
62
+ "label": overall_label
63
+ },
64
+ "paragraphs": paragraph_results
65
  }
66
 
67
+ # --- Request Schema ---
68
+ class TextInput(BaseModel):
69
+ text: str
70
 
71
+ # --- API Routes ---
72
+ @app.get("/")
73
+ async def root():
74
+ return {"status": "ok", "message": "AI Text Detector API is running."}
75
+
76
+ @app.post("/analyze")
77
+ async def analyze(input_data: TextInput):
78
+ return analyze_text(input_data.text)