RafzE committed on
Commit
1583931
·
verified ·
1 Parent(s): f7d6571

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -26
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel, validator
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
  import logging
@@ -32,23 +32,25 @@ app.add_middleware(
32
  allow_headers=["*"],
33
  )
34
 
35
- # ---------------- Models ----------------
36
  class ScanRequest(BaseModel):
37
  text: str
38
- # Accept both scan_type and scanType
39
  scan_type: Optional[str] = None
40
  scanType: Optional[str] = None
41
  userId: Optional[str] = None
42
 
43
- @validator('scan_type', 'scanType', pre=True, always=True)
44
- def determine_scan_type(cls, v, values, field):
45
- if field.name == 'scanType' and v:
46
- # Map scanType to scan_type for internal use
47
- values['scan_type'] = v
 
 
48
  return v
49
-
50
  def get_scan_type(self) -> str:
51
- """Get the scan type, defaulting to 'basic' if not provided"""
 
52
  return self.scan_type or "basic"
53
 
54
  class ScanResponse(BaseModel):
@@ -58,7 +60,7 @@ class ScanResponse(BaseModel):
58
  credits: Optional[dict] = None
59
  test_mode: bool = False
60
 
61
- # ---------------- AI Detector ----------------
62
  MODEL_NAME = "openai-community/roberta-large-openai-detector"
63
 
64
  class AIDetector:
@@ -93,7 +95,7 @@ class AIDetector:
93
  logger.info("Model loaded successfully.")
94
 
95
  def predict(self, text: str, max_length: int = 512) -> dict:
96
- """Return both human and AI probabilities with debugging info"""
97
  if self.model is None:
98
  self.load_model()
99
 
@@ -116,9 +118,7 @@ class AIDetector:
116
  ai_prob = float(probs[0][1].item()) # Class 1
117
 
118
  # Debug logging
119
- logger.debug(f"Raw probabilities: {probs}")
120
- logger.debug(f"Class 0 (Human): {human_prob:.4f}")
121
- logger.debug(f"Class 1 (AI): {ai_prob:.4f}")
122
 
123
  # Verify probabilities sum to ~1.0
124
  total = human_prob + ai_prob
@@ -133,9 +133,9 @@ class AIDetector:
133
 
134
  detector = AIDetector()
135
 
136
- # ---------------- ChatGPT Pattern Detection ----------------
137
  def detect_chatgpt_patterns(text: str) -> bool:
138
- """Return True if ChatGPT patterns are detected"""
139
  patterns = [
140
  "as an ai language model",
141
  "i am an ai model",
@@ -237,10 +237,10 @@ def compute_overall_score(sections: List[dict], confidence_threshold: float = 0.
237
  "confident_sections": len(confident_sections)
238
  }
239
 
240
- # ---------------- Endpoints ----------------
241
  @app.on_event("startup")
242
  async def startup():
243
- """Initialize the model on startup"""
244
  logger.info("Starting Detextly AI Detector API...")
245
  try:
246
  detector.load_model()
@@ -257,22 +257,21 @@ async def root():
257
  "device": str(detector.device),
258
  "version": "2.1.0",
259
  "features": ["basic_scan", "highlight_scan", "chatgpt_pattern_detection"],
260
- "endpoints": ["POST /api/scan", "GET /health", "GET /debug/test"]
261
  }
262
 
263
  @app.get("/health")
264
  async def health():
265
- health_status = {
266
  "status": "healthy",
267
  "model_loaded": detector.model is not None,
268
  "model": MODEL_NAME,
269
  "timestamp": time.time()
270
  }
271
- return health_status
272
 
273
  @app.get("/debug/test")
274
  async def debug_test():
275
- """Test endpoint to verify model is working correctly"""
276
  test_texts = [
277
  "I went to the store yesterday to buy groceries.",
278
  "As an AI language model, I don't have personal experiences.",
@@ -299,7 +298,7 @@ async def debug_test():
299
 
300
  @app.post("/api/scan", response_model=ScanResponse)
301
  async def scan_text(request: ScanRequest):
302
- """Main scanning endpoint"""
303
  start_time = time.time()
304
 
305
  try:
@@ -307,7 +306,7 @@ async def scan_text(request: ScanRequest):
307
  if not request.text or len(request.text.strip()) < 10:
308
  raise HTTPException(status_code=400, detail="Text must be at least 10 characters long.")
309
 
310
- # Get scan type (handles both scan_type and scanType)
311
  scan_type = request.get_scan_type()
312
  logger.info(f"Scan request: type={scan_type}, userId={request.userId}, text_length={len(request.text)}")
313
 
@@ -398,7 +397,7 @@ async def scan_text(request: ScanRequest):
398
 
399
  @app.get("/api/credits")
400
  async def get_credits(userId: Optional[str] = None):
401
- """Get credits information (for compatibility with worker)"""
402
  return {
403
  "basic": 5,
404
  "highlight": 1,
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel, field_validator, ValidationInfo
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
  import logging
 
32
  allow_headers=["*"],
33
  )
34
 
35
+ # ---------------- Pydantic Models ----------------
36
  class ScanRequest(BaseModel):
37
  text: str
 
38
  scan_type: Optional[str] = None
39
  scanType: Optional[str] = None
40
  userId: Optional[str] = None
41
 
42
+ @field_validator('scanType')
43
+ @classmethod
44
+ def map_scantype_to_scan_type(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
45
+ """Mapper to ensure backward compatibility with old 'scanType' parameter name."""
46
+ if v is not None:
47
+ # Map the old 'scanType' field value to the new 'scan_type' field
48
+ info.data['scan_type'] = v
49
  return v
50
+
51
  def get_scan_type(self) -> str:
52
+ """Get the scan type, defaulting to 'basic' if not provided."""
53
+ # scan_type takes precedence as it's the canonical field name
54
  return self.scan_type or "basic"
55
 
56
  class ScanResponse(BaseModel):
 
60
  credits: Optional[dict] = None
61
  test_mode: bool = False
62
 
63
+ # ---------------- AI Detector Core ----------------
64
  MODEL_NAME = "openai-community/roberta-large-openai-detector"
65
 
66
  class AIDetector:
 
95
  logger.info("Model loaded successfully.")
96
 
97
  def predict(self, text: str, max_length: int = 512) -> dict:
98
+ """Return both human and AI probabilities."""
99
  if self.model is None:
100
  self.load_model()
101
 
 
118
  ai_prob = float(probs[0][1].item()) # Class 1
119
 
120
  # Debug logging
121
+ logger.debug(f"Class 0 (Human): {human_prob:.4f}, Class 1 (AI): {ai_prob:.4f}")
 
 
122
 
123
  # Verify probabilities sum to ~1.0
124
  total = human_prob + ai_prob
 
133
 
134
  detector = AIDetector()
135
 
136
+ # ---------------- Pattern Detection ----------------
137
  def detect_chatgpt_patterns(text: str) -> bool:
138
+ """Return True if ChatGPT patterns are detected."""
139
  patterns = [
140
  "as an ai language model",
141
  "i am an ai model",
 
237
  "confident_sections": len(confident_sections)
238
  }
239
 
240
+ # ---------------- API Endpoints ----------------
241
  @app.on_event("startup")
242
  async def startup():
243
+ """Initialize the model on startup."""
244
  logger.info("Starting Detextly AI Detector API...")
245
  try:
246
  detector.load_model()
 
257
  "device": str(detector.device),
258
  "version": "2.1.0",
259
  "features": ["basic_scan", "highlight_scan", "chatgpt_pattern_detection"],
260
+ "note": "Accepts both 'scan_type' and 'scanType' parameters"
261
  }
262
 
263
  @app.get("/health")
264
  async def health():
265
+ return {
266
  "status": "healthy",
267
  "model_loaded": detector.model is not None,
268
  "model": MODEL_NAME,
269
  "timestamp": time.time()
270
  }
 
271
 
272
  @app.get("/debug/test")
273
  async def debug_test():
274
+ """Test endpoint to verify model is working correctly."""
275
  test_texts = [
276
  "I went to the store yesterday to buy groceries.",
277
  "As an AI language model, I don't have personal experiences.",
 
298
 
299
  @app.post("/api/scan", response_model=ScanResponse)
300
  async def scan_text(request: ScanRequest):
301
+ """Main scanning endpoint."""
302
  start_time = time.time()
303
 
304
  try:
 
306
  if not request.text or len(request.text.strip()) < 10:
307
  raise HTTPException(status_code=400, detail="Text must be at least 10 characters long.")
308
 
309
+ # Get scan type (handles both scan_type and scanType via the validator)
310
  scan_type = request.get_scan_type()
311
  logger.info(f"Scan request: type={scan_type}, userId={request.userId}, text_length={len(request.text)}")
312
 
 
397
 
398
  @app.get("/api/credits")
399
  async def get_credits(userId: Optional[str] = None):
400
+ """Get credits information (for compatibility with worker)."""
401
  return {
402
  "basic": 5,
403
  "highlight": 1,