mahmoudsaber0 commited on
Commit
0dbeec3
·
verified ·
1 Parent(s): 943ecd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -107
app.py CHANGED
@@ -1,116 +1,220 @@
1
  import os
 
 
2
  import torch
3
- from fastapi import FastAPI, WebSocket, UploadFile, File
4
- from fastapi.responses import JSONResponse
5
- from fastapi.middleware.cors import CORSMiddleware
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
- from huggingface_hub import hf_hub_download
8
-
9
- # ==========================================================
10
- # ✅ Environment Setup (Safe Caching)
11
- # ==========================================================
12
- os.environ["HF_HOME"] = "/tmp/huggingface"
13
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
14
- os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
15
- os.makedirs("/tmp/huggingface", exist_ok=True)
16
- os.makedirs("/tmp/models", exist_ok=True)
17
-
18
- # ==========================================================
19
- # ✅ Model Paths
20
- # ==========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  model1_path = "modernbert.bin"
22
- model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
23
- model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
24
-
25
- # ==========================================================
26
- # Helper: Load Model (Local or Remote)
27
- # ==========================================================
28
- def load_model(path_or_url, model_name="modernbert"):
29
- try:
30
- if path_or_url.startswith("http"):
31
- print(f"Downloading model from {path_or_url}...")
32
- filename = os.path.join("/tmp/models", os.path.basename(path_or_url))
33
- if not os.path.exists(filename):
34
- torch.hub.download_url_to_file(path_or_url, filename)
35
- path_or_url = filename
36
-
37
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
38
- model = AutoModelForSequenceClassification.from_pretrained(
39
- "bert-base-uncased",
40
- state_dict=torch.load(path_or_url, map_location="cpu"),
41
  )
42
- model.eval()
43
- print(f"✅ Loaded model: {path_or_url}")
44
- return tokenizer, model
45
- except Exception as e:
46
- print(f"❌ Failed to load model from {path_or_url}: {e}")
47
- return None, None
48
-
49
- # ==========================================================
50
- # Load All Models
51
- # ==========================================================
52
- tokenizer1, model1 = load_model(model1_path, "model1")
53
- tokenizer2, model2 = load_model(model2_path, "model2")
54
- tokenizer3, model3 = load_model(model3_path, "model3")
55
-
56
- # ==========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # ✅ FastAPI Setup
58
- # ==========================================================
59
- app = FastAPI(title="ModernBERT AI Detection API")
60
-
61
- app.add_middleware(
62
- CORSMiddleware,
63
- allow_origins=["*"], allow_credentials=True,
64
- allow_methods=["*"], allow_headers=["*"]
65
- )
66
-
67
- # ==========================================================
68
- # ✅ Root Route
69
- # ==========================================================
70
- @app.get("/")
71
- def home():
72
- return {"message": "ModernBERT API is running"}
73
-
74
- # ==========================================================
75
- # ✅ Text Inference Endpoint
76
- # ==========================================================
77
  @app.post("/analyze")
78
- async def analyze_text(data: dict):
79
- text = data.get("text", "")
80
- if not text.strip():
81
- return JSONResponse({"error": "No text provided"}, status_code=400)
82
 
83
- inputs = tokenizer1(text, return_tensors="pt", truncation=True, padding=True)
84
- results = {}
 
 
 
85
 
86
- with torch.no_grad():
87
- if model1:
88
- results["model1"] = torch.softmax(model1(**inputs).logits, dim=-1).tolist()[0]
89
- if model2:
90
- results["model2"] = torch.softmax(model2(**inputs).logits, dim=-1).tolist()[0]
91
- if model3:
92
- results["model3"] = torch.softmax(model3(**inputs).logits, dim=-1).tolist()[0]
93
-
94
- avg_score = torch.tensor([results[m][0] for m in results]).mean().item()
95
- return {"results": results, "avg_score": avg_score}
96
-
97
- # ==========================================================
98
- # ✅ WebSocket for Real-Time Scoring
99
- # ==========================================================
100
- @app.websocket("/ws")
101
- async def websocket_endpoint(websocket: WebSocket):
102
- await websocket.accept()
103
- try:
104
- while True:
105
- data = await websocket.receive_text()
106
- inputs = tokenizer1(data, return_tensors="pt", truncation=True, padding=True)
107
- with torch.no_grad():
108
- output = torch.softmax(model1(**inputs).logits, dim=-1)
109
- await websocket.send_json({"score": output.tolist()[0]})
110
- except Exception:
111
- await websocket.close()
112
-
113
- # ==========================================================
114
- # Run Command
115
- # ==========================================================
116
- # uvicorn app:app --host 0.0.0.0 --
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
3
+ import shutil
4
  import torch
5
+ from fastapi import FastAPI
6
+ from pydantic import BaseModel
7
+
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+ from tokenizers.normalizers import Sequence, Replace, Strip
10
+ from tokenizers import Regex
11
+ import atexit
12
+
13
+ # =====================================================
14
+ # Safe Cache Setup (Runtime)
15
+ # =====================================================
16
+ CACHE_DIR = "/tmp/huggingface"
17
+
18
+ # Clear any old cache to prevent exceeding 50G
19
+ if os.path.exists(CACHE_DIR):
20
+ shutil.rmtree(CACHE_DIR, ignore_errors=True)
21
+ os.makedirs(CACHE_DIR, exist_ok=True)
22
+
23
+ # Set environment paths
24
+ os.environ.update({
25
+ "HF_HOME": CACHE_DIR,
26
+ "TRANSFORMERS_CACHE": CACHE_DIR,
27
+ "HF_DATASETS_CACHE": CACHE_DIR,
28
+ "HF_HUB_CACHE": CACHE_DIR,
29
+ "TORCH_HOME": CACHE_DIR,
30
+ "XDG_CACHE_HOME": CACHE_DIR,
31
+ "TORCHINDUCTOR_CACHE_DIR": CACHE_DIR,
32
+ "TORCH_LOGS": "off"
33
+ })
34
+
35
+ # Auto cleanup on shutdown
36
+ @atexit.register
37
+ def cleanup_cache():
38
+ shutil.rmtree(CACHE_DIR, ignore_errors=True)
39
+
40
+ # =====================================================
41
+ # ✅ Model Setup
42
+ # =====================================================
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+
45
+ tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
46
+
47
+ # --- Model paths ---
48
  model1_path = "modernbert.bin"
49
+ model2_url = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
50
+ model3_url = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
51
+
52
+ def load_model(base_path=None, url=None):
53
+ model = AutoModelForSequenceClassification.from_pretrained(
54
+ "answerdotai/ModernBERT-base", num_labels=41
55
+ )
56
+ if url:
57
+ state_dict = torch.hub.load_state_dict_from_url(
58
+ url, map_location=device, progress=False, check_hash=False
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
  )
69
+ else:
70
+ state_dict = torch.load(base_path, map_location=device)
71
+ model.load_state_dict(state_dict)
72
+ model.to(device).eval()
73
+ return model
74
+
75
+ model_1 = load_model(model1_path)
76
+ model_2 = load_model(url=model2_url)
77
+ model_3 = load_model(url=model3_url)
78
+
79
+ # =====================================================
80
+ # Label Mapping
81
+ # =====================================================
82
+ label_mapping = {
83
+ 0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
84
+ 6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
85
+ 11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
86
+ 14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
87
+ 18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
88
+ 22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
89
+ 27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
90
+ 31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
91
+ 35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
92
+ 39: 'text-davinci-002', 40: 'text-davinci-003'
93
+ }
94
+
95
+ # =====================================================
96
+ # ✅ Text Cleaning & Tokenizer Normalization
97
+ # =====================================================
98
+ def clean_text(text: str) -> str:
99
+ text = re.sub(r'\s{2,}', ' ', text)
100
+ text = re.sub(r'\s+([,.;:?!])', r'\1', text)
101
+ return text
102
+
103
+ newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
104
+ join_hyphen_break = Replace(Regex(r'(\w+)[--]\s*\n\s*(\w+)'), r"\1\2")
105
+ tokenizer.backend_tokenizer.normalizer = Sequence([
106
+ tokenizer.backend_tokenizer.normalizer,
107
+ join_hyphen_break,
108
+ newline_to_space,
109
+ Strip()
110
+ ])
111
+
112
+ # =====================================================
113
+ # ✅ Analysis Logic
114
+ # =====================================================
115
+ def analyze_text_block(text: str):
116
+ cleaned_text = clean_text(text)
117
+ inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)
118
+
119
+ with torch.no_grad():
120
+ logits_1 = model_1(**inputs).logits
121
+ logits_2 = model_2(**inputs).logits
122
+ logits_3 = model_3(**inputs).logits
123
+
124
+ avg_probs = (
125
+ torch.softmax(logits_1, dim=1) +
126
+ torch.softmax(logits_2, dim=1) +
127
+ torch.softmax(logits_3, dim=1)
128
+ ) / 3
129
+
130
+ probs = avg_probs[0]
131
+ human_prob = probs[24].item()
132
+ ai_probs = probs.clone()
133
+ ai_probs[24] = 0
134
+ ai_total_prob = ai_probs.sum().item()
135
+
136
+ total = human_prob + ai_total_prob
137
+ human_percentage = (human_prob / total) * 100
138
+ ai_percentage = (ai_total_prob / total) * 100
139
+ ai_model_index = torch.argmax(ai_probs).item()
140
+
141
+ return {
142
+ "human_written_score": round(human_percentage / 100, 4),
143
+ "ai_generated_score": round(ai_percentage / 100, 4),
144
+ "predicted_model": label_mapping[ai_model_index]
145
+ }
146
+
147
+ def split_into_paragraphs(text: str):
148
+ return [p.strip() for p in re.split(r'\n\s*\n', text.strip()) if p.strip()]
149
+
150
+ # =====================================================
151
  # ✅ FastAPI Setup
152
+ # =====================================================
153
+ app = FastAPI(title="ModernBERT AI Text Detector")
154
+
155
+ class InputText(BaseModel):
156
+ text: str
157
+
158
+ @app.get("/health")
159
+ async def health():
160
+ return {"status": "ok"}
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
  @app.post("/analyze")
172
+ async def analyze(data: InputText):
173
+ text = data.text.strip()
174
+ if not text:
175
+ return {"success": False, "code": 400, "message": "Empty input text"}
176
 
177
+ total_words = len(text.split())
178
+ full_result = analyze_text_block(text)
179
+ fake_percentage = round(full_result["ai_generated_score"] * 100, 2)
180
+ ai_words = int(total_words * (fake_percentage / 100))
181
+ results = []
182
 
183
+ if fake_percentage > 50:
184
+ paragraphs = split_into_paragraphs(text)
185
+ ai_words, total_words = 0, 0
186
+ for p in paragraphs:
187
+ res = analyze_text_block(p)
188
+ wc = len(p.split())
189
+ total_words += wc
190
+ ai_words += wc * res["ai_generated_score"]
191
+ results.append({
192
+ "paragraph": p,
193
+ "ai_generated_score": res["ai_generated_score"],
194
+ "human_written_score": res["human_written_score"],
195
+ "predicted_model": res["predicted_model"]
196
+ })
197
+ fake_percentage = round((ai_words / total_words) * 100, 2)
198
+
199
+ feedback = (
200
+ "Most of Your Text is AI/GPT Generated"
201
+ if fake_percentage > 50
202
+ else "Most of Your Text Appears Human-Written"
203
+ )
204
+
205
+ return {
206
+ "success": True,
207
+ "code": 200,
208
+ "message": "analysis completed",
209
+ "data": {
210
+ "fakePercentage": fake_percentage,
211
+ "isHuman": round(100 - fake_percentage, 2),
212
+ "textWords": total_words,
213
+ "aiWords": ai_words,
214
+ "paragraphs": results,
215
+ "predicted_model": full_result["predicted_model"],
216
+ "feedback": feedback,
217
+ "input_text": text,
218
+ "detected_language": "en"
219
+ }
220
+ }