File size: 3,197 Bytes
b25f064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f8daf8
b25f064
3f8daf8
 
b25f064
 
 
 
 
 
 
 
3f8daf8
 
 
b25f064
 
3f8daf8
 
 
 
 
 
a147390
3f8daf8
 
 
 
 
 
 
 
b25f064
3f8daf8
 
b25f064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging

logger = logging.getLogger(__name__)

class AIDetectorModel:
    """AI-generated-text detector backed by a sequence-classification model.

    Prefers a locally quantized checkpoint; on any failure it falls back to
    the base Hugging Face model. Inference is forced onto the CPU.
    """

    # Base model tag used when the local quantized checkpoint cannot be loaded.
    FALLBACK_MODEL_TAG = "yuchuantian/AIGC_detector_env3short"

    def __init__(self, model_dir: str, model_filename: str):
        """Record paths only; actual loading is deferred to load().

        Args:
            model_dir: Directory holding the tokenizer files and checkpoint.
            model_filename: File name of the quantized checkpoint in model_dir.
        """
        self.model_dir = model_dir
        self.model_path = os.path.join(model_dir, model_filename)
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cpu")  # Force CPU as requested

    def load(self) -> None:
        """Load the tokenizer and model.

        Tries the local quantized checkpoint first, then falls back to the
        base Hugging Face model. On failure of either path, partially-loaded
        state is reset so the instance is never left half-initialized.

        Raises:
            RuntimeError: If both the quantized and the fallback model fail
                to load.
        """
        logger.info("Loading AI Detector model from: %s...", self.model_path)

        # 1. Try loading the local quantized model.
        try:
            # Tokenizer files live next to the checkpoint (saved by
            # download_model.py).
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

            # The checkpoint is a full pickled model object, so
            # weights_only=False is required. NOTE(security): torch.load
            # unpickles arbitrary code — only load checkpoints produced by
            # our own pipeline, never untrusted files.
            try:
                self.model = torch.load(
                    self.model_path, map_location=self.device, weights_only=False
                )
            except TypeError:
                # Older torch versions do not accept the weights_only kwarg.
                self.model = torch.load(self.model_path, map_location=self.device)

            self.model.eval()
            logger.info("AI Detector quantized model loaded successfully.")
            return

        except Exception as e:
            # Reset any half-loaded state (e.g. tokenizer loaded but model
            # failed) so the fallback path starts clean.
            self.tokenizer = None
            self.model = None
            logger.warning(
                "Failed to load quantized model: %s. Attempting fallback...", e
            )

        # 2. Fall back to the base Hugging Face model.
        logger.info("Loading fallback model: %s...", self.FALLBACK_MODEL_TAG)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.FALLBACK_MODEL_TAG)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.FALLBACK_MODEL_TAG
            )
            self.model.to(self.device)
            self.model.eval()
            logger.info("Fallback model loaded successfully.")
        except Exception as e:
            # Leave the instance in a consistent "not loaded" state.
            self.tokenizer = None
            self.model = None
            logger.error("Failed to load fallback model: %s", e)
            raise RuntimeError(
                f"Failed to load AI Detector model (both quantized and fallback failed): {e}"
            ) from e

    def predict(self, text: str) -> dict:
        """Classify *text* as human- or AI-written.

        Args:
            text: Input text; tokenized with truncation to 512 tokens.

        Returns:
            Dict with keys "label" ("Human"/"AI"), "score" (probability of
            the predicted class) and "probabilities" (per-class probability
            list, index 0 -> Human, 1 -> AI).

        Raises:
            RuntimeError: If load() has not completed successfully.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded.")

        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()

        # Label convention of the detector checkpoint: 0 -> Human, 1 -> AI.
        labels_map = {0: "Human", 1: "AI"}
        return {
            "label": labels_map.get(predicted_class, "Unknown"),
            "score": probabilities[0][predicted_class].item(),
            "probabilities": probabilities.cpu().numpy().tolist()[0],
        }