# Latest-app / src/detector.py
# (Hugging Face Spaces page header removed — was: "abhi099k's picture",
#  "Update src/detector.py", commit 09abcdc verified. The raw header text
#  is not valid Python and would break the module at import time.)
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
import numpy as np
import re
import os
import time
from pathlib import Path
# Configure cache for Hugging Face Spaces
# NOTE(review): `transformers` is imported above *before* these env vars are
# set, so some library code paths may read them too late; the explicit
# `cache_dir=` arguments passed at load time below are the reliable
# mechanism — confirm whether the env vars are still needed.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['HF_HOME'] = '/tmp/huggingface'
# Create cache directories
Path('/tmp/transformers_cache').mkdir(parents=True, exist_ok=True)
Path('/tmp/huggingface').mkdir(parents=True, exist_ok=True)
# Hugging Face Hub repository id of the fine-tuned sequence classifier.
MODEL_DIR = "abhi099k/ai-text-detector-v-n4.0"
# Use GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize as None, load on first use
# Module-level singletons populated lazily by get_components().
_tokenizer = None
_config = None
_model = None
def get_components():
    """Lazily load and cache the tokenizer, config, and model.

    On first call, downloads/loads the components from ``MODEL_DIR`` with up
    to three attempts, backing off and clearing stale ``*.lock`` files in the
    cache directory between attempts (lock conflicts on shared Spaces hosts
    surface as ``OSError``). Subsequent calls return the cached singletons.

    Returns:
        tuple: ``(tokenizer, config, model)`` with the model in eval mode
        on the module-level ``device``.

    Raises:
        OSError: if loading still fails after all retries.
    """
    global _tokenizer, _config, _model
    if _tokenizer is None:
        max_retries = 3
        for attempt in range(max_retries):
            try:
                print(f"Loading model components... (Attempt {attempt + 1}/{max_retries})")
                _tokenizer = AutoTokenizer.from_pretrained(
                    MODEL_DIR,
                    cache_dir='/tmp/transformers_cache',
                    local_files_only=False
                )
                _config = AutoConfig.from_pretrained(
                    MODEL_DIR,
                    cache_dir='/tmp/transformers_cache',
                    local_files_only=False
                )
                _model = AutoModelForSequenceClassification.from_pretrained(
                    MODEL_DIR,
                    config=_config,
                    cache_dir='/tmp/transformers_cache',
                    local_files_only=False
                ).to(device)
                _model.eval()
                print("Model loaded successfully!")
                break
            except OSError as e:
                if attempt < max_retries - 1:
                    wait_time = (attempt + 1) * 2  # linear backoff: 2s, 4s
                    print(f"Cache conflict detected, retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                    # Best-effort removal of stale lock files left behind by a
                    # crashed or concurrent download before the next attempt.
                    cache_path = Path('/tmp/transformers_cache')
                    if cache_path.exists():
                        for lock_file in cache_path.glob("*.lock"):
                            try:
                                lock_file.unlink()
                                print(f"Removed lock file: {lock_file}")
                            except OSError:
                                # Was a bare `except:` — narrowed so that
                                # KeyboardInterrupt/SystemExit still propagate;
                                # a held/already-removed lock is ignorable.
                                pass
                else:
                    print(f"Failed to load model after {max_retries} attempts: {e}")
                    raise
    return _tokenizer, _config, _model
# === Preprocessing: Normalize + Flatten ===
def preprocess_text_for_detection(text: str) -> str:
    """Flatten structured notes (bullets, lists) into plain sentences.

    Newlines and bullet/dash markers become sentence breaks, runs of
    whitespace collapse to single spaces, and punctuation spacing is
    normalized. Non-string or empty input yields "".
    """
    if not isinstance(text, str) or not text:
        return ""
    # Newlines and bullet/dash markers -> sentence breaks.
    flattened = re.sub(r"[\n•\-–]+", ". ", text)
    # Collapse any whitespace run to a single space.
    flattened = re.sub(r"\s+", " ", flattened)
    # Exactly one space after punctuation, none before it.
    flattened = re.sub(r"\s*([,.!?;:])\s*", r"\1 ", flattened)
    return flattened.strip()
# === Core Scoring ===
def score_text(text, max_len=512):
    """Return the probability (float in [0, 1]) that *text* is AI-generated.

    The text is tokenized (truncated to ``max_len`` tokens), passed through
    the classifier, and the softmax probability of label index 1 ("AI") is
    returned.
    """
    tokenizer, _, model = get_components()
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    ).to(device)
    # Some architectures do not accept token_type_ids; drop it if present.
    inputs.pop("token_type_ids", None)
    with torch.no_grad():
        output = model(**inputs)
    probabilities = torch.softmax(output.logits, dim=-1)
    # [0][1]: first (only) example, probability of the "AI" label.
    return float(probabilities[0][1].item())
# === Artifact Detection ===
def has_html_or_ai_artifacts(text: str) -> bool:
    """Return True if *text* contains HTML tags or ``data-start``/``data-end``
    attributes typical of copy-pasted AI chat output; False otherwise
    (including for empty/None input)."""
    if not text:
        return False
    looks_like_tag = re.search(r'<[^>]+>', text) is not None
    looks_like_data_attr = re.search(r'data-(start|end)=["\']?\d+', text) is not None
    return looks_like_tag or looks_like_data_attr
# === Main Prediction Function ===
def analyze_text(text, threshold=0.5, chunk_size=80):
    """Analyze text and classify it as AI-generated or human-written.

    Args:
        text (str): Input text to analyze.
        threshold (float): AI-probability cutoff in [0, 1]; scores at or
            above it are labeled "AI", below it "Human".
        chunk_size (int): Accepted for backward compatibility but currently
            unused — the whole (truncated) text is scored in a single pass.

    Returns:
        dict: On success, ``overall_type`` ("AI"/"Human"),
        ``overall_confidence`` (probability of the chosen label),
        ``overall_score`` (raw AI probability) and ``has_artifacts``.
        On failure, an ``error`` message plus zeroed defaults.
    """
    # Guard clause: empty/None/whitespace-only input never reaches the model.
    if not text or not text.strip():
        return {
            "error": "No text provided",
            "overall_type": "Unknown",
            "overall_confidence": 0.0,
            "overall_score": 0.0
        }
    try:
        # HTML tags / data-* attributes hint at copy-pasted chat output.
        has_artifacts = has_html_or_ai_artifacts(text)
        # Flatten bullets/newlines into sentences before scoring.
        processed_text = preprocess_text_for_detection(text)
        if not processed_text:
            return {
                "error": "Text too short or invalid after preprocessing",
                "overall_type": "Unknown",
                "overall_confidence": 0.0,
                "overall_score": 0.0
            }
        ai_score = score_text(processed_text)
        overall_type = "AI" if ai_score >= threshold else "Human"
        # Confidence is the probability of whichever label was chosen.
        overall_confidence = ai_score if overall_type == "AI" else (1 - ai_score)
        return {
            "overall_type": overall_type,
            "overall_confidence": float(overall_confidence),
            "overall_score": float(ai_score),
            "has_artifacts": has_artifacts
        }
    except Exception as e:
        # Boundary handler: report the failure in the result dict instead of
        # raising into the UI layer.
        return {
            "error": f"Analysis failed: {str(e)}",
            "overall_type": "Error",
            "overall_confidence": 0.0,
            "overall_score": 0.0
        }
# Pre-load model when module is imported (optional)
# Best-effort warm-up so the first request doesn't pay the download/load
# cost. Failures are deliberately non-fatal: the broad catch at this module
# boundary only logs, and get_components() will retry lazily on first use.
try:
    print("Pre-loading model components...")
    get_components()
    print("Model pre-loaded successfully!")
except Exception as e:
    print(f"Pre-loading failed, will load on first use: {e}")