mahmoudsaber0 commited on
Commit
dde6bd9
·
verified ·
1 Parent(s): 080d131

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -68
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import os
2
- from fastapi import FastAPI, WebSocket, UploadFile, File
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi.responses import JSONResponse
5
  import torch
6
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
7
- import asyncio
 
8
 
9
  # =====================================================
10
- # ✅ Fix Hugging Face Cache Permission Errors
11
  # =====================================================
12
  CACHE_DIR = "/tmp/hf_cache"
13
  os.environ["HF_HOME"] = CACHE_DIR
@@ -17,80 +16,102 @@ os.environ["HF_HUB_CACHE"] = CACHE_DIR
17
  os.makedirs(CACHE_DIR, exist_ok=True)
18
 
19
  # =====================================================
20
- # ✅ Initialize Model and Tokenizer
21
  # =====================================================
22
- MODEL_NAME = "answerdotai/ModernBERT-base"
23
 
24
- print("Loading model and tokenizer...")
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
26
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
27
 
28
- classifier = pipeline(
29
- "text-classification",
30
- model=model,
31
- tokenizer=tokenizer,
32
- device=0 if torch.cuda.is_available() else -1
33
- )
34
 
35
  # =====================================================
36
- # ✅ FastAPI App Setup
37
  # =====================================================
38
- app = FastAPI(title="ModernBERT FastAPI Server")
39
-
40
- # Allow all origins (for testing)
41
- app.add_middleware(
42
- CORSMiddleware,
43
- allow_origins=["*"],
44
- allow_credentials=True,
45
- allow_methods=["*"],
46
- allow_headers=["*"],
47
- )
48
 
49
  # =====================================================
50
- # ✅ REST Endpoint Example
51
  # =====================================================
52
- @app.post("/analyze")
53
- async def analyze_text(data: dict):
54
- try:
55
- text = data.get("text", "")
56
- if not text.strip():
57
- return JSONResponse({"error": "Empty text provided"}, status_code=400)
58
 
59
- result = classifier(text)
60
- return {"result": result}
61
- except Exception as e:
62
- return JSONResponse({"error": str(e)}, status_code=500)
 
 
63
 
64
- # =====================================================
65
- # ✅ WebSocket Endpoint (real-time classification)
66
- # =====================================================
67
- @app.websocket("/ws")
68
- async def websocket_endpoint(ws: WebSocket):
69
- await ws.accept()
70
- idle_timeout = 60 # seconds
71
-
72
- async def close_if_idle():
73
- while True:
74
- await asyncio.sleep(idle_timeout)
75
- await ws.close(code=1000)
76
- break
77
-
78
- asyncio.create_task(close_if_idle())
79
-
80
- try:
81
- while True:
82
- message = await ws.receive_text()
83
- if message.lower() in ["exit", "quit"]:
84
- await ws.close(code=1000)
85
- break
86
- result = classifier(message)
87
- await ws.send_json(result)
88
- except Exception:
89
- await ws.close()
90
 
91
  # =====================================================
92
- # ✅ Root Endpoint
93
  # =====================================================
94
  @app.get("/")
95
- def home():
96
- return {"status": "ok", "model": MODEL_NAME, "device": "cuda" if torch.cuda.is_available() else "cpu"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
 
 
3
  import torch
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
 
8
  # =====================================================
9
+ # ✅ Safe Hugging Face Cache Configuration
10
  # =====================================================
11
  CACHE_DIR = "/tmp/hf_cache"
12
  os.environ["HF_HOME"] = CACHE_DIR
 
16
  os.makedirs(CACHE_DIR, exist_ok=True)
17
 
18
  # =====================================================
19
+ # ✅ Load Model and Tokenizer
20
  # =====================================================
21
+ MODEL_NAME = "roberta-base-openai-detector"
22
 
23
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
24
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
 
25
 
26
+ app = FastAPI(title="AI Text Detector")
 
 
 
 
 
27
 
28
  # =====================================================
29
+ # ✅ Input Schema
30
  # =====================================================
31
+ class InputText(BaseModel):
32
+ text: str
 
 
 
 
 
 
 
 
33
 
34
  # =====================================================
35
+ # ✅ Helper Functions
36
  # =====================================================
37
+ def split_into_paragraphs(text: str):
38
+ """Split text into paragraphs by double newlines or long single breaks."""
39
+ paragraphs = re.split(r'\n\s*\n', text.strip())
40
+ paragraphs = [p.strip() for p in paragraphs if len(p.strip()) > 0]
41
+ return paragraphs
 
42
 
43
+ def analyze_text_block(text: str):
44
+ """Analyze a single paragraph and return AI/Human probability."""
45
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
46
+ with torch.no_grad():
47
+ logits = model(**inputs).logits
48
+ probs = torch.softmax(logits, dim=1)[0].tolist()
49
 
50
+ return {
51
+ "label_scores": {
52
+ model.config.id2label[0]: round(probs[0], 4),
53
+ model.config.id2label[1]: round(probs[1], 4)
54
+ },
55
+ "ai_generated_score": probs[1],
56
+ "human_written_score": probs[0],
57
+ "is_ai": probs[1] > probs[0]
58
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # =====================================================
61
+ # ✅ Routes
62
  # =====================================================
63
  @app.get("/")
64
+ def root():
65
+ return {"message": "AI Text Detector is running. Use POST /analyze with {'text': 'your text'}"}
66
+
67
+ @app.post("/analyze")
68
+ async def analyze(data: InputText):
69
+ text = data.text.strip()
70
+ if not text:
71
+ return {"success": False, "code": 400, "message": "Empty input text"}
72
+
73
+ paragraphs = split_into_paragraphs(text)
74
+ results = []
75
+
76
+ ai_words, total_words = 0, 0
77
+
78
+ for paragraph in paragraphs:
79
+ res = analyze_text_block(paragraph)
80
+ results.append({
81
+ "paragraph": paragraph,
82
+ "ai_generated_score": res["ai_generated_score"],
83
+ "human_written_score": res["human_written_score"]
84
+ })
85
+
86
+ word_count = len(paragraph.split())
87
+ total_words += word_count
88
+ ai_words += word_count * res["ai_generated_score"]
89
+
90
+ fake_percentage = round((ai_words / total_words) * 100, 2) if total_words > 0 else 0
91
+ feedback = (
92
+ "Most of Your Text is AI/GPT Generated"
93
+ if fake_percentage > 50
94
+ else "Most of Your Text Appears Human-Written"
95
+ )
96
+
97
+ return {
98
+ "success": True,
99
+ "code": 200,
100
+ "message": "detection result passed to proxy",
101
+ "data": {
102
+ "sentences": [],
103
+ "isHuman": round(100 - fake_percentage, 2),
104
+ "additional_feedback": "",
105
+ "h": [r["paragraph"] for r in results],
106
+ "hi": [],
107
+ "textWords": total_words,
108
+ "aiWords": int(total_words * (fake_percentage / 100)),
109
+ "fakePercentage": fake_percentage,
110
+ "specialIndexes": [],
111
+ "specialSentences": [],
112
+ "originalParagraph": text,
113
+ "feedback": feedback,
114
+ "input_text": text,
115
+ "detected_language": "en"
116
+ }
117
+ }