File size: 6,528 Bytes
13464bf
8eaaeee
 
 
 
 
 
 
 
 
 
 
 
 
21abb82
8eaaeee
21abb82
8eaaeee
b4746b6
21abb82
8eaaeee
b4746b6
 
8eaaeee
 
b4746b6
8eaaeee
13464bf
8eaaeee
 
 
 
 
 
 
 
 
 
 
13464bf
 
 
 
 
8eaaeee
13464bf
 
8eaaeee
 
 
 
 
 
 
 
b4746b6
8eaaeee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4746b6
13464bf
 
 
 
8eaaeee
 
 
 
 
 
 
 
 
 
 
 
 
 
5efd7a3
8eaaeee
13464bf
8eaaeee
 
 
 
 
 
 
 
707b788
 
 
 
 
 
 
8eaaeee
 
 
 
 
 
 
 
 
 
13464bf
 
5efd7a3
13464bf
8eaaeee
13464bf
8eaaeee
 
 
 
 
 
 
 
 
 
 
5efd7a3
8eaaeee
 
 
 
 
 
 
 
 
 
 
707b788
 
 
 
 
 
 
 
 
8eaaeee
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import torch
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download

# Ensure local detree package is importable
# This allows the script to find the 'detree' package if it sits in the same directory
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during initial setup check).
    # Detector is left as None so the loading code below skips initialization
    # and the prediction functions degrade to fallback responses instead of
    # crashing at import time.
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None


# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"  # Hugging Face repo hosting the model files
TEXT_SUBFOLDER = "Lib/Models/Text"   # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"  # embedding database file loaded by the Detector
MAX_LEN = 512  # tokenizer max_length passed to the Detector at construction

# Local path of the downloaded Text subfolder; stays None if the download fails,
# which causes the detector setup below to be skipped.
MODEL_DIR = None

try:
    # download a local snapshot of just the Text folder and point MODEL_DIR at it
    # (snapshot_download caches locally, so repeat runs reuse the cached copy)
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    # Network/auth failures are tolerated: MODEL_DIR remains None and the
    # detector initialization below is skipped with a message.
    print(f"Error downloading model from Hugging Face: {e}")

# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
# Prefer GPU when available; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

# Module-level detector shared by predict_text/predict_batch.
# None signals "unavailable": both functions then return fallback responses.
detector = None

try:
    if Detector and MODEL_DIR:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        
        # Advisory pre-flight checks: print warnings early, but let Detector()
        # be the real authority — any actual failure is caught by the except below.
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")
            print(f"Please ensure '{EMBEDDING_FILE}' is present in '{TEXT_SUBFOLDER}' of the Hugging Face repo.")

        # Initialize DETree Detector
        # This loads the model from MODEL_DIR and the embeddings from database_path
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max" # Default pooling
        )
        print(f"Text classification model (DETree) loaded successfully")
    else:
        # Report exactly which prerequisite was missing (package vs. model files).
        if not Detector:
            print("DETree detector could not be initialized due to missing package.")
        if not MODEL_DIR:
            print("DETree detector could not be initialized due to missing model directory.")

except Exception as e:
    # Any load failure leaves detector as None; predictions degrade gracefully.
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")

# ── 3) Inference function ──────────────────────────────────────────────────────
def predict_text(text: str, max_length: int = None):
    """
    Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify.
        max_length (int, optional): Ignored in this implementation — DETree
            applies the global MAX_LEN set at detector construction time;
            kept for backward compatibility with the previous API.

    Returns:
        dict: {"predicted_class": "Human" | "AI", "confidence": float}.
              When the detector is unavailable or prediction fails, a fallback
              of {"predicted_class": "Human", "confidence": -100.0} is returned
              (the sentinel confidence marks the result as unreliable).
    """
    fallback = {"predicted_class": "Human", "confidence": -100.0}

    if detector is None:
        return fallback

    try:
        # detector.predict expects a list of strings
        predictions = detector.predict([text])
        if not predictions:
            return fallback

        pred = predictions[0]
        # Pick the class with the higher probability and report that
        # probability as the confidence.  (The original code computed
        # confidence twice with identical results; collapsed to one pass.)
        if pred.probability_ai > pred.probability_human:
            label, confidence = "AI", pred.probability_ai
        else:
            label, confidence = "Human", pred.probability_human

        return {
            "predicted_class": label,
            "confidence": float(confidence)
        }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return fallback

# ── 4) Batch prediction ────────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.

    Args:
        texts (list): List of text strings to classify.
        batch_size (int): Batch size applied to the detector for this call;
            the detector's original batch size is restored afterwards.

    Returns:
        list: One dict per input text with keys "text", "predicted_class"
              ("Human" or "AI") and "confidence" (float). When the detector
              is unavailable or prediction fails, one fallback entry
              {"predicted_class": "Human", "confidence": -100.0} is returned
              per input text.
    """
    if detector is None:
        return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]

    # Temporarily override the detector's batch size so the argument is honored;
    # restored in the finally block even if prediction raises.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size

    try:
        predictions = detector.predict(texts)
        results = []
        for text, pred in zip(texts, predictions):
            # Choose the class with the higher probability (mirrors predict_text).
            # Note: the dead `label = pred.label` assignment that was always
            # overwritten has been removed.
            if pred.probability_ai > pred.probability_human:
                label, confidence = "AI", pred.probability_ai
            else:
                label, confidence = "Human", pred.probability_human

            results.append({
                "text": text,
                "predicted_class": label,
                "confidence": float(confidence)
            })
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
    finally:
        detector.batch_size = original_batch_size