import re
import string
from typing import Any

import joblib
from nltk.corpus import stopwords

from schemas.text_schemas import AITextDetector


class NBAITextDetector(AITextDetector):
    """
    Naive Bayes AI text detector that classifies a text as AI-generated
    or human-written using a pre-trained joblib model.

    Attributes:
        model_path (str): Path to the saved joblib model.
        model (Any): Loaded ML model used for prediction.
        min_words (int): Minimum number of words required for detection.
        stop_words (set[str]): English stopwords removed during preprocessing.
    """

    def __init__(self, model_path: str = "./models/ai_text_detector.joblib", min_words: int = 0):
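        """
        Initialize the NBAITextDetector.

        Args:
            model_path (str, optional): Path to the trained joblib model file.
            min_words (int, optional): Minimum number of words required for
                detection. The default of 0 disables the length check.

        Note:
            The NLTK English stopwords corpus must be available locally
            (e.g. via ``nltk.download("stopwords")``); otherwise NLTK raises
            a ``LookupError``.
        """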
        self.model_path = model_path
        self.min_words = min_words
        self.model = self._load_model()
        self.stop_words = set(stopwords.words('english'))

    def _load_model(self) -> Any:
        """Load the trained joblib model."""
        return joblib.load(self.model_path)

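    def _preprocess_text(self, text: str) -> str:
        """Clean and normalize the input text before classification.

        URLs, email addresses, hashtags, and mentions are removed before
        punctuation is stripped; otherwise the ``.``, ``@``, and ``#``
        characters these patterns rely on would already be gone and the
        patterns could never match.

        This must stay consistent with the preprocessing used to train the model.
        """
        text = text.lower()
        # Drop URLs, emails, hashtags, and @mentions while their markers still exist.
        text = re.sub(r'http\S+|www\.\S+', '', text)
        text = re.sub(r'\S+@\S+\.\S+', '', text)
        text = re.sub(r'#[A-Za-z0-9_]+', '', text)
        text = re.sub(r'@[A-Za-z0-9_]+', '', text)
        # Drop digits and punctuation.
        text = re.sub(r'\d+', '', text)
        text = text.translate(str.maketrans('', '', string.punctuation))
        # Normalize whitespace first so newlines and tabs are not deleted
        # (and words merged) by the non-printable filter below.
        text = re.sub(r'\s+', ' ', text)
        text = ''.join(ch for ch in text if ch.isprintable())
        # Remove stopwords; split() also collapses any leftover whitespace.
        return ' '.join(word for word in text.split() if word not in self.stop_words)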

    def detect(self, text: str) -> bool:
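        """
        Detect whether a given text is AI-generated.

        Args:
            text (str): Input text to classify.

        Returns:
            bool: True if the text is classified as AI-generated (label 1),
            False if it is classified as human-written.

        Raises:
            ValueError: If the text has fewer than ``min_words`` words.
        """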
        if len(text.split()) < self.min_words:
            raise ValueError(f"Text must be at least {self.min_words} words long.")

        processed = self._preprocess_text(text)
        prediction = self.model.predict([processed])
        # Label 1 denotes AI-generated text; any other label is treated as human-written.
        return int(prediction[0]) == 1