Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
-
from pydantic import BaseModel,
|
| 4 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 5 |
import torch
|
| 6 |
import logging
|
|
@@ -32,23 +32,25 @@ app.add_middleware(
|
|
| 32 |
allow_headers=["*"],
|
| 33 |
)
|
| 34 |
|
| 35 |
-
# ---------------- Models ----------------
|
| 36 |
class ScanRequest(BaseModel):
|
| 37 |
text: str
|
| 38 |
-
# Accept both scan_type and scanType
|
| 39 |
scan_type: Optional[str] = None
|
| 40 |
scanType: Optional[str] = None
|
| 41 |
userId: Optional[str] = None
|
| 42 |
|
| 43 |
-
@
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
| 48 |
return v
|
| 49 |
-
|
| 50 |
def get_scan_type(self) -> str:
|
| 51 |
-
"""Get the scan type, defaulting to 'basic' if not provided"""
|
|
|
|
| 52 |
return self.scan_type or "basic"
|
| 53 |
|
| 54 |
class ScanResponse(BaseModel):
|
|
@@ -58,7 +60,7 @@ class ScanResponse(BaseModel):
|
|
| 58 |
credits: Optional[dict] = None
|
| 59 |
test_mode: bool = False
|
| 60 |
|
| 61 |
-
# ---------------- AI Detector ----------------
|
| 62 |
MODEL_NAME = "openai-community/roberta-large-openai-detector"
|
| 63 |
|
| 64 |
class AIDetector:
|
|
@@ -93,7 +95,7 @@ class AIDetector:
|
|
| 93 |
logger.info("Model loaded successfully.")
|
| 94 |
|
| 95 |
def predict(self, text: str, max_length: int = 512) -> dict:
|
| 96 |
-
"""Return both human and AI probabilities
|
| 97 |
if self.model is None:
|
| 98 |
self.load_model()
|
| 99 |
|
|
@@ -116,9 +118,7 @@ class AIDetector:
|
|
| 116 |
ai_prob = float(probs[0][1].item()) # Class 1
|
| 117 |
|
| 118 |
# Debug logging
|
| 119 |
-
logger.debug(f"
|
| 120 |
-
logger.debug(f"Class 0 (Human): {human_prob:.4f}")
|
| 121 |
-
logger.debug(f"Class 1 (AI): {ai_prob:.4f}")
|
| 122 |
|
| 123 |
# Verify probabilities sum to ~1.0
|
| 124 |
total = human_prob + ai_prob
|
|
@@ -133,9 +133,9 @@ class AIDetector:
|
|
| 133 |
|
| 134 |
detector = AIDetector()
|
| 135 |
|
| 136 |
-
# ----------------
|
| 137 |
def detect_chatgpt_patterns(text: str) -> bool:
|
| 138 |
-
"""Return True if ChatGPT patterns are detected"""
|
| 139 |
patterns = [
|
| 140 |
"as an ai language model",
|
| 141 |
"i am an ai model",
|
|
@@ -237,10 +237,10 @@ def compute_overall_score(sections: List[dict], confidence_threshold: float = 0.
|
|
| 237 |
"confident_sections": len(confident_sections)
|
| 238 |
}
|
| 239 |
|
| 240 |
-
# ---------------- Endpoints ----------------
|
| 241 |
@app.on_event("startup")
|
| 242 |
async def startup():
|
| 243 |
-
"""Initialize the model on startup"""
|
| 244 |
logger.info("Starting Detextly AI Detector API...")
|
| 245 |
try:
|
| 246 |
detector.load_model()
|
|
@@ -257,22 +257,21 @@ async def root():
|
|
| 257 |
"device": str(detector.device),
|
| 258 |
"version": "2.1.0",
|
| 259 |
"features": ["basic_scan", "highlight_scan", "chatgpt_pattern_detection"],
|
| 260 |
-
"
|
| 261 |
}
|
| 262 |
|
| 263 |
@app.get("/health")
|
| 264 |
async def health():
|
| 265 |
-
|
| 266 |
"status": "healthy",
|
| 267 |
"model_loaded": detector.model is not None,
|
| 268 |
"model": MODEL_NAME,
|
| 269 |
"timestamp": time.time()
|
| 270 |
}
|
| 271 |
-
return health_status
|
| 272 |
|
| 273 |
@app.get("/debug/test")
|
| 274 |
async def debug_test():
|
| 275 |
-
"""Test endpoint to verify model is working correctly"""
|
| 276 |
test_texts = [
|
| 277 |
"I went to the store yesterday to buy groceries.",
|
| 278 |
"As an AI language model, I don't have personal experiences.",
|
|
@@ -299,7 +298,7 @@ async def debug_test():
|
|
| 299 |
|
| 300 |
@app.post("/api/scan", response_model=ScanResponse)
|
| 301 |
async def scan_text(request: ScanRequest):
|
| 302 |
-
"""Main scanning endpoint"""
|
| 303 |
start_time = time.time()
|
| 304 |
|
| 305 |
try:
|
|
@@ -307,7 +306,7 @@ async def scan_text(request: ScanRequest):
|
|
| 307 |
if not request.text or len(request.text.strip()) < 10:
|
| 308 |
raise HTTPException(status_code=400, detail="Text must be at least 10 characters long.")
|
| 309 |
|
| 310 |
-
# Get scan type (handles both scan_type and scanType)
|
| 311 |
scan_type = request.get_scan_type()
|
| 312 |
logger.info(f"Scan request: type={scan_type}, userId={request.userId}, text_length={len(request.text)}")
|
| 313 |
|
|
@@ -398,7 +397,7 @@ async def scan_text(request: ScanRequest):
|
|
| 398 |
|
| 399 |
@app.get("/api/credits")
|
| 400 |
async def get_credits(userId: Optional[str] = None):
|
| 401 |
-
"""Get credits information (for compatibility with worker)"""
|
| 402 |
return {
|
| 403 |
"basic": 5,
|
| 404 |
"highlight": 1,
|
|
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from pydantic import BaseModel, field_validator, ValidationInfo
|
| 4 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 5 |
import torch
|
| 6 |
import logging
|
|
|
|
| 32 |
allow_headers=["*"],
|
| 33 |
)
|
| 34 |
|
| 35 |
+
# ---------------- Pydantic Models ----------------
|
| 36 |
class ScanRequest(BaseModel):
    """Request body accepted by the POST /api/scan endpoint.

    The scan type may be sent under either the snake_case ``scan_type`` key
    or the legacy camelCase ``scanType`` key; use ``get_scan_type()`` to read
    the effective value instead of accessing either field directly.
    """

    text: str
    scan_type: Optional[str] = None
    scanType: Optional[str] = None
    userId: Optional[str] = None

    @field_validator('scanType')
    @classmethod
    def map_scantype_to_scan_type(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Mapper to ensure backward compatibility with old 'scanType' parameter name."""
        if v is not None:
            # NOTE(review): this writes the legacy value into the dict of
            # already-validated fields. Whether that write reaches the final
            # model is not something Pydantic documents, so get_scan_type()
            # below does not rely on it -- confirm before removing either.
            info.data['scan_type'] = v
        return v

    def get_scan_type(self) -> str:
        """Get the scan type, defaulting to 'basic' if not provided.

        ``scan_type`` is read first; the legacy ``scanType`` field is used as
        an explicit fallback so a request that only sends ``scanType`` is
        still honoured even if the validator's copy did not take effect.
        """
        return self.scan_type or self.scanType or "basic"
|
| 55 |
|
| 56 |
class ScanResponse(BaseModel):
|
|
|
|
| 60 |
credits: Optional[dict] = None
|
| 61 |
test_mode: bool = False
|
| 62 |
|
| 63 |
+
# ---------------- AI Detector Core ----------------
|
| 64 |
MODEL_NAME = "openai-community/roberta-large-openai-detector"
|
| 65 |
|
| 66 |
class AIDetector:
|
|
|
|
| 95 |
logger.info("Model loaded successfully.")
|
| 96 |
|
| 97 |
def predict(self, text: str, max_length: int = 512) -> dict:
|
| 98 |
+
"""Return both human and AI probabilities."""
|
| 99 |
if self.model is None:
|
| 100 |
self.load_model()
|
| 101 |
|
|
|
|
| 118 |
ai_prob = float(probs[0][1].item()) # Class 1
|
| 119 |
|
| 120 |
# Debug logging
|
| 121 |
+
logger.debug(f"Class 0 (Human): {human_prob:.4f}, Class 1 (AI): {ai_prob:.4f}")
|
|
|
|
|
|
|
| 122 |
|
| 123 |
# Verify probabilities sum to ~1.0
|
| 124 |
total = human_prob + ai_prob
|
|
|
|
| 133 |
|
| 134 |
detector = AIDetector()
|
| 135 |
|
| 136 |
+
# ---------------- Pattern Detection ----------------
|
| 137 |
def detect_chatgpt_patterns(text: str) -> bool:
|
| 138 |
+
"""Return True if ChatGPT patterns are detected."""
|
| 139 |
patterns = [
|
| 140 |
"as an ai language model",
|
| 141 |
"i am an ai model",
|
|
|
|
| 237 |
"confident_sections": len(confident_sections)
|
| 238 |
}
|
| 239 |
|
| 240 |
+
# ---------------- API Endpoints ----------------
|
| 241 |
@app.on_event("startup")
|
| 242 |
async def startup():
|
| 243 |
+
"""Initialize the model on startup."""
|
| 244 |
logger.info("Starting Detextly AI Detector API...")
|
| 245 |
try:
|
| 246 |
detector.load_model()
|
|
|
|
| 257 |
"device": str(detector.device),
|
| 258 |
"version": "2.1.0",
|
| 259 |
"features": ["basic_scan", "highlight_scan", "chatgpt_pattern_detection"],
|
| 260 |
+
"note": "Accepts both 'scan_type' and 'scanType' parameters"
|
| 261 |
}
|
| 262 |
|
| 263 |
@app.get("/health")
async def health():
    """Report service liveness and whether the detector model is loaded."""
    model_is_loaded = detector.model is not None
    payload = {
        "status": "healthy",
        "model_loaded": model_is_loaded,
        "model": MODEL_NAME,
        "timestamp": time.time(),
    }
    return payload
|
|
|
|
| 271 |
|
| 272 |
@app.get("/debug/test")
|
| 273 |
async def debug_test():
|
| 274 |
+
"""Test endpoint to verify model is working correctly."""
|
| 275 |
test_texts = [
|
| 276 |
"I went to the store yesterday to buy groceries.",
|
| 277 |
"As an AI language model, I don't have personal experiences.",
|
|
|
|
| 298 |
|
| 299 |
@app.post("/api/scan", response_model=ScanResponse)
|
| 300 |
async def scan_text(request: ScanRequest):
|
| 301 |
+
"""Main scanning endpoint."""
|
| 302 |
start_time = time.time()
|
| 303 |
|
| 304 |
try:
|
|
|
|
| 306 |
if not request.text or len(request.text.strip()) < 10:
|
| 307 |
raise HTTPException(status_code=400, detail="Text must be at least 10 characters long.")
|
| 308 |
|
| 309 |
+
# Get scan type (handles both scan_type and scanType via the validator)
|
| 310 |
scan_type = request.get_scan_type()
|
| 311 |
logger.info(f"Scan request: type={scan_type}, userId={request.userId}, text_length={len(request.text)}")
|
| 312 |
|
|
|
|
| 397 |
|
| 398 |
@app.get("/api/credits")
|
| 399 |
async def get_credits(userId: Optional[str] = None):
|
| 400 |
+
"""Get credits information (for compatibility with worker)."""
|
| 401 |
return {
|
| 402 |
"basic": 5,
|
| 403 |
"highlight": 1,
|