mahmoudsaber0 committed on
Commit
943ecd3
·
verified ·
1 Parent(s): 4f638b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -192
app.py CHANGED
@@ -1,201 +1,116 @@
1
  import os
2
- import re
3
- import shutil
4
  import torch
5
- from fastapi import FastAPI
6
- from pydantic import BaseModel
 
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
- from tokenizers.normalizers import Sequence, Replace, Strip
9
- from tokenizers import Regex
10
- import atexit
11
-
12
- # =====================================================
13
- # βœ… Safe Cache Setup (Runtime)
14
- # =====================================================
15
- CACHE_DIR = "/tmp/huggingface"
16
-
17
- # Clear any old cache to prevent exceeding 50G
18
- if os.path.exists(CACHE_DIR):
19
- shutil.rmtree(CACHE_DIR, ignore_errors=True)
20
- os.makedirs(CACHE_DIR, exist_ok=True)
21
-
22
- # Set environment paths
23
- os.environ.update({
24
- "HF_HOME": CACHE_DIR,
25
- "TRANSFORMERS_CACHE": CACHE_DIR,
26
- "HF_DATASETS_CACHE": CACHE_DIR,
27
- "HF_HUB_CACHE": CACHE_DIR,
28
- "TORCH_HOME": CACHE_DIR,
29
- "XDG_CACHE_HOME": CACHE_DIR,
30
- "TORCHINDUCTOR_CACHE_DIR": CACHE_DIR,
31
- "TORCH_LOGS": "off"
32
- })
33
-
34
- # Auto cleanup on shutdown
35
- @atexit.register
36
- def cleanup_cache():
37
- shutil.rmtree(CACHE_DIR, ignore_errors=True)
38
-
39
- # =====================================================
40
- # βœ… Model Setup
41
- # =====================================================
42
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
-
44
- tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
45
-
46
- # --- Model paths ---
47
  model1_path = "modernbert.bin"
48
- model2_url = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
49
- model3_url = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
50
-
51
- def load_model(base_path=None, url=None):
52
- model = AutoModelForSequenceClassification.from_pretrained(
53
- "answerdotai/ModernBERT-base", num_labels=41
54
- )
55
- if url:
56
- state_dict = torch.hub.load_state_dict_from_url(
57
- url, map_location=device, progress=False, check_hash=False
 
 
 
 
 
 
 
 
 
58
  )
59
- else:
60
- state_dict = torch.load(base_path, map_location=device)
61
- model.load_state_dict(state_dict)
62
- model.to(device).eval()
63
- return model
64
-
65
- model_1 = load_model(model1_path)
66
- model_2 = load_model(url=model2_url)
67
- model_3 = load_model(url=model3_url)
68
-
69
- # =====================================================
70
- # βœ… Label Mapping
71
- # =====================================================
72
- label_mapping = {
73
- 0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
74
- 6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
75
- 11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
76
- 14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
77
- 18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
78
- 22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
79
- 27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
80
- 31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
81
- 35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
82
- 39: 'text-davinci-002', 40: 'text-davinci-003'
83
- }
84
-
85
- # =====================================================
86
- # βœ… Text Cleaning & Tokenizer Normalization
87
- # =====================================================
88
- def clean_text(text: str) -> str:
89
- text = re.sub(r'\s{2,}', ' ', text)
90
- text = re.sub(r'\s+([,.;:?!])', r'\1', text)
91
- return text
92
-
93
- newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
94
- join_hyphen_break = Replace(Regex(r'(\w+)[--]\s*\n\s*(\w+)'), r"\1\2")
95
- tokenizer.backend_tokenizer.normalizer = Sequence([
96
- tokenizer.backend_tokenizer.normalizer,
97
- join_hyphen_break,
98
- newline_to_space,
99
- Strip()
100
- ])
101
-
102
- # =====================================================
103
- # βœ… Analysis Logic
104
- # =====================================================
105
- def analyze_text_block(text: str):
106
- cleaned_text = clean_text(text)
107
- inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)
108
-
109
- with torch.no_grad():
110
- logits_1 = model_1(**inputs).logits
111
- logits_2 = model_2(**inputs).logits
112
- logits_3 = model_3(**inputs).logits
113
-
114
- avg_probs = (
115
- torch.softmax(logits_1, dim=1) +
116
- torch.softmax(logits_2, dim=1) +
117
- torch.softmax(logits_3, dim=1)
118
- ) / 3
119
-
120
- probs = avg_probs[0]
121
- human_prob = probs[24].item()
122
- ai_probs = probs.clone()
123
- ai_probs[24] = 0
124
- ai_total_prob = ai_probs.sum().item()
125
-
126
- total = human_prob + ai_total_prob
127
- human_percentage = (human_prob / total) * 100
128
- ai_percentage = (ai_total_prob / total) * 100
129
- ai_model_index = torch.argmax(ai_probs).item()
130
-
131
- return {
132
- "human_written_score": round(human_percentage / 100, 4),
133
- "ai_generated_score": round(ai_percentage / 100, 4),
134
- "predicted_model": label_mapping[ai_model_index]
135
- }
136
-
137
- def split_into_paragraphs(text: str):
138
- return [p.strip() for p in re.split(r'\n\s*\n', text.strip()) if p.strip()]
139
-
140
- # =====================================================
141
  # βœ… FastAPI Setup
142
- # =====================================================
143
- app = FastAPI(title="ModernBERT AI Text Detector")
144
-
145
- class InputText(BaseModel):
146
- text: str
147
-
148
- @app.get("/health")
149
- async def health():
150
- return {"status": "ok"}
151
-
 
 
 
 
 
 
 
 
 
152
  @app.post("/analyze")
153
- async def analyze(data: InputText):
154
- text = data.text.strip()
155
- if not text:
156
- return {"success": False, "code": 400, "message": "Empty input text"}
157
 
158
- total_words = len(text.split())
159
- full_result = analyze_text_block(text)
160
- fake_percentage = round(full_result["ai_generated_score"] * 100, 2)
161
- ai_words = int(total_words * (fake_percentage / 100))
162
- results = []
163
 
164
- if fake_percentage > 50:
165
- paragraphs = split_into_paragraphs(text)
166
- ai_words, total_words = 0, 0
167
- for p in paragraphs:
168
- res = analyze_text_block(p)
169
- wc = len(p.split())
170
- total_words += wc
171
- ai_words += wc * res["ai_generated_score"]
172
- results.append({
173
- "paragraph": p,
174
- "ai_generated_score": res["ai_generated_score"],
175
- "human_written_score": res["human_written_score"],
176
- "predicted_model": res["predicted_model"]
177
- })
178
- fake_percentage = round((ai_words / total_words) * 100, 2)
179
-
180
- feedback = (
181
- "Most of Your Text is AI/GPT Generated"
182
- if fake_percentage > 50
183
- else "Most of Your Text Appears Human-Written"
184
- )
185
-
186
- return {
187
- "success": True,
188
- "code": 200,
189
- "message": "analysis completed",
190
- "data": {
191
- "fakePercentage": fake_percentage,
192
- "isHuman": round(100 - fake_percentage, 2),
193
- "textWords": total_words,
194
- "aiWords": ai_words,
195
- "paragraphs": results,
196
- "predicted_model": full_result["predicted_model"],
197
- "feedback": feedback,
198
- "input_text": text,
199
- "detected_language": "en"
200
- }
201
- }
 
1
  import os
 
 
2
  import torch
3
+ from fastapi import FastAPI, WebSocket, UploadFile, File
4
+ from fastapi.responses import JSONResponse
5
+ from fastapi.middleware.cors import CORSMiddleware
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from huggingface_hub import hf_hub_download
8
+
9
# ==========================================================
# ✅ Environment Setup (Safe Caching)
# ==========================================================
# Route every Hugging Face cache location to /tmp so the app only ever
# writes to the ephemeral disk, then make sure the directories exist.
_CACHE_DIR = "/tmp/huggingface"
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE"):
    os.environ[_cache_var] = _CACHE_DIR
os.makedirs(_CACHE_DIR, exist_ok=True)
os.makedirs("/tmp/models", exist_ok=True)
17
+
18
# ==========================================================
# ✅ Model Paths
# ==========================================================
# model1 is expected to ship alongside the app as a local checkpoint file;
# model2/model3 are raw state-dict files fetched from the Hugging Face Hub.
model1_path = "modernbert.bin"
model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
24
+
25
# ==========================================================
# ✅ Helper: Load Model (Local or Remote)
# ==========================================================
def load_model(path_or_url, model_name="modernbert"):
    """Load a fine-tuned ModernBERT sequence-classification checkpoint.

    Parameters
    ----------
    path_or_url : str
        Local path to a state-dict file, or an http(s) URL to download
        one from (downloads are cached under /tmp/models).
    model_name : str
        Informal label, kept for backward compatibility / logging only.

    Returns
    -------
    tuple
        (tokenizer, model) on success, or (None, None) on any failure —
        callers must check for None before using the pair.
    """
    try:
        if path_or_url.startswith("http"):
            print(f"Downloading model from {path_or_url}...")
            filename = os.path.join("/tmp/models", os.path.basename(path_or_url))
            # Reuse a previously downloaded copy when present.
            if not os.path.exists(filename):
                torch.hub.download_url_to_file(path_or_url, filename)
            path_or_url = filename

        # These checkpoints are ModernBERT classifier heads with 41 labels
        # (the same URLs were loaded into answerdotai/ModernBERT-base with
        # num_labels=41 by the previous revision of this app). Loading them
        # into bert-base-uncased's default 2-label head raises a state-dict
        # size mismatch, which the except below would silently swallow.
        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
        model = AutoModelForSequenceClassification.from_pretrained(
            "answerdotai/ModernBERT-base",
            num_labels=41,
        )
        model.load_state_dict(torch.load(path_or_url, map_location="cpu"))
        model.eval()
        print(f"✅ Loaded model: {path_or_url}")
        return tokenizer, model
    except Exception as e:
        # Deliberate best-effort: a missing model degrades the API
        # (endpoints guard against None) rather than crashing startup.
        print(f"❌ Failed to load model from {path_or_url}: {e}")
        return None, None
48
+
49
# ==========================================================
# ✅ Load All Models
# ==========================================================
# Runs at import time and may download weights over the network.
# Any of these pairs may be (None, None) if loading failed; the
# endpoints below must tolerate that.
tokenizer1, model1 = load_model(model1_path, "model1")
tokenizer2, model2 = load_model(model2_path, "model2")
tokenizer3, model3 = load_model(model3_path, "model3")
55
+
56
# ==========================================================
# ✅ FastAPI Setup
# ==========================================================
app = FastAPI(title="ModernBERT AI Detection API")

# Wide-open CORS: any origin may call the API. Acceptable for a public
# demo; tighten allow_origins before exposing anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"]
)
66
+
67
# ==========================================================
# ✅ Root Route
# ==========================================================
@app.get("/")
def home():
    """Liveness probe: confirms the service is up and reachable."""
    status_message = "ModernBERT API is running"
    return {"message": status_message}
73
+
74
# ==========================================================
# ✅ Text Inference Endpoint
# ==========================================================
@app.post("/analyze")
async def analyze_text(data: dict):
    """Score a piece of text with every successfully-loaded model.

    Request body: {"text": "..."}.
    Returns {"results": {model_name: [class probabilities]},
             "avg_score": mean of each model's class-0 probability},
    400 on empty text, 503 when no model is available.
    """
    text = data.get("text", "")
    if not text.strip():
        return JSONResponse({"error": "No text provided"}, status_code=400)

    # load_model() returns (None, None) on failure, so any of the three
    # tokenizer/model pairs may be missing — degrade instead of crashing
    # with a TypeError on a None tokenizer.
    if tokenizer1 is None:
        return JSONResponse({"error": "Models unavailable"}, status_code=503)

    inputs = tokenizer1(text, return_tensors="pt", truncation=True, padding=True)
    results = {}

    with torch.no_grad():
        if model1:
            results["model1"] = torch.softmax(model1(**inputs).logits, dim=-1).tolist()[0]
        if model2:
            results["model2"] = torch.softmax(model2(**inputs).logits, dim=-1).tolist()[0]
        if model3:
            results["model3"] = torch.softmax(model3(**inputs).logits, dim=-1).tolist()[0]

    # Guard the empty case: torch.tensor([]).mean() is NaN, which is not
    # JSON-serializable and would have produced a confusing response.
    if not results:
        return JSONResponse({"error": "Models unavailable"}, status_code=503)

    # NOTE(review): this averages each model's probability for class index 0
    # — confirm index 0 is the intended "AI" class for these checkpoints.
    avg_score = sum(results[m][0] for m in results) / len(results)
    return {"results": results, "avg_score": avg_score}
96
+
97
# ==========================================================
# ✅ WebSocket for Real-Time Scoring
# ==========================================================
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Stream a probability vector back for each text frame received."""
    await websocket.accept()
    # model1/tokenizer1 are (None, None) if loading failed at startup;
    # without this guard every frame would raise a TypeError.
    if tokenizer1 is None or model1 is None:
        await websocket.close(code=1011)  # 1011 = internal error
        return
    try:
        while True:
            data = await websocket.receive_text()
            inputs = tokenizer1(data, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                output = torch.softmax(model1(**inputs).logits, dim=-1)
            await websocket.send_json({"score": output.tolist()[0]})
    except Exception:
        # Client disconnects surface here as exceptions; close() can itself
        # raise if the socket is already gone, so swallow that as well.
        try:
            await websocket.close()
        except Exception:
            pass
112
+
113
# ==========================================================
# ✅ Run Command
# ==========================================================
# uvicorn app:app --host 0.0.0.0 --port 7860