Spaces:
Running
Running
Update textPreprocess.py
Browse files- textPreprocess.py +130 -66
textPreprocess.py
CHANGED
|
@@ -1,93 +1,157 @@
|
|
| 1 |
-
import os
|
| 2 |
import torch
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
# ββ 1)
|
| 7 |
REPO_ID = "MAS-AI-0000/Authentica"
|
| 8 |
TEXT_SUBFOLDER = "Lib/Models/Text" # where config.json/model.safetensors live in the repo
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
_snapshot_dir = snapshot_download(
|
| 12 |
-
repo_id=REPO_ID,
|
| 13 |
-
allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
|
| 14 |
-
)
|
| 15 |
-
MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
|
| 16 |
-
|
| 17 |
-
# individual file paths (in case you need them elsewhere)
|
| 18 |
-
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
|
| 19 |
-
MODEL_SAFETENSORS_PATH = os.path.join(MODEL_DIR, "model.safetensors")
|
| 20 |
-
TOKENIZER_JSON_PATH = os.path.join(MODEL_DIR, "tokenizer.json")
|
| 21 |
-
TOKENIZER_CONFIG_PATH = os.path.join(MODEL_DIR, "tokenizer_config.json")
|
| 22 |
-
SPECIAL_TOKENS_MAP_PATH = os.path.join(MODEL_DIR, "special_tokens_map.json")
|
| 23 |
-
TRAINING_ARGS_BIN_PATH = os.path.join(MODEL_DIR, "training_args.bin") # optional
|
| 24 |
-
TEXT_TXT_PATH = os.path.join(MODEL_DIR, "text.txt") # optional
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# ββ 2) Load model & tokenizer ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 30 |
print(f"Text prediction device: {device}")
|
| 31 |
|
| 32 |
-
|
| 33 |
-
model = None
|
| 34 |
-
ID2LABEL = {0: "human", 1: "ai"}
|
| 35 |
|
| 36 |
try:
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
print("Text classification model loaded successfully")
|
| 48 |
-
print("MODEL_DIR:", MODEL_DIR)
|
| 49 |
-
print("Labels:", ID2LABEL)
|
| 50 |
except Exception as e:
|
| 51 |
print(f"Error loading text model: {e}")
|
| 52 |
print("Text prediction will return fallback responses")
|
| 53 |
|
| 54 |
-
# ββ 3) Inference
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
return {"predicted_class": "Human", "confidence": -100.0}
|
| 60 |
-
|
| 61 |
-
if max_length is None:
|
| 62 |
-
max_length = MAX_LEN
|
| 63 |
-
|
| 64 |
try:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
except Exception as e:
|
| 73 |
print(f"Error during text prediction: {e}")
|
| 74 |
return {"predicted_class": "Human", "confidence": -100.0}
|
| 75 |
|
| 76 |
-
# ββ 4) Batch
|
| 77 |
-
@torch.inference_mode()
|
| 78 |
def predict_batch(texts, batch_size=16):
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
+
|
| 7 |
+
# Ensure local detree package is importable
|
| 8 |
+
# This allows the script to find the 'detree' package if it sits in the same directory
|
| 9 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 10 |
+
if current_dir not in sys.path:
|
| 11 |
+
sys.path.append(current_dir)
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from detree.inference import Detector
|
| 15 |
+
except ImportError:
|
| 16 |
+
# Fallback if detree is not found (e.g. during initial setup check)
|
| 17 |
+
print("Warning: 'detree' package not found. Please ensure the 'detree' folder is in the same directory.")
|
| 18 |
+
Detector = None
|
| 19 |
|
| 20 |
+
# ββ 1) Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
REPO_ID = "MAS-AI-0000/Authentica"
|
| 22 |
TEXT_SUBFOLDER = "Lib/Models/Text" # where config.json/model.safetensors live in the repo
|
| 23 |
+
EMBEDDING_FILE = "priori1_center10k.pt"
|
| 24 |
+
MAX_LEN = 512
|
| 25 |
|
| 26 |
+
MODEL_DIR = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
try:
|
| 29 |
+
# download a local snapshot of just the Text folder and point MODEL_DIR at it
|
| 30 |
+
print(f"Downloading/Checking model from {REPO_ID}...")
|
| 31 |
+
_snapshot_dir = snapshot_download(
|
| 32 |
+
repo_id=REPO_ID,
|
| 33 |
+
allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
|
| 34 |
+
)
|
| 35 |
+
MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
|
| 36 |
+
print(f"Model directory set to: {MODEL_DIR}")
|
| 37 |
+
except Exception as e:
|
| 38 |
+
print(f"Error downloading model from Hugging Face: {e}")
|
| 39 |
|
| 40 |
# ββ 2) Load model & tokenizer ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 42 |
print(f"Text prediction device: {device}")
|
| 43 |
|
| 44 |
+
detector = None
|
|
|
|
|
|
|
| 45 |
|
| 46 |
try:
|
| 47 |
+
if Detector and MODEL_DIR:
|
| 48 |
+
database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
|
| 49 |
+
|
| 50 |
+
if not os.path.exists(MODEL_DIR):
|
| 51 |
+
print(f"Warning: Model directory not found at {MODEL_DIR}")
|
| 52 |
+
if not os.path.exists(database_path):
|
| 53 |
+
print(f"Warning: Embedding file not found at {database_path}")
|
| 54 |
+
print(f"Please ensure '{EMBEDDING_FILE}' is present in '{TEXT_SUBFOLDER}' of the Hugging Face repo.")
|
| 55 |
|
| 56 |
+
# Initialize DETree Detector
|
| 57 |
+
# This loads the model from MODEL_DIR and the embeddings from database_path
|
| 58 |
+
detector = Detector(
|
| 59 |
+
database_path=database_path,
|
| 60 |
+
model_name_or_path=MODEL_DIR,
|
| 61 |
+
device=device,
|
| 62 |
+
max_length=MAX_LEN,
|
| 63 |
+
pooling="max" # Default pooling
|
| 64 |
+
)
|
| 65 |
+
print(f"Text classification model (DETree) loaded successfully")
|
| 66 |
+
else:
|
| 67 |
+
if not Detector:
|
| 68 |
+
print("DETree detector could not be initialized due to missing package.")
|
| 69 |
+
if not MODEL_DIR:
|
| 70 |
+
print("DETree detector could not be initialized due to missing model directory.")
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
except Exception as e:
|
| 73 |
print(f"Error loading text model: {e}")
|
| 74 |
print("Text prediction will return fallback responses")
|
| 75 |
|
| 76 |
+
# ββ 3) Inference function ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
def predict_text(text: str, max_length: int = None):
|
| 78 |
+
"""
|
| 79 |
+
Predict whether the given text is human-written or AI-generated using DETree.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
text (str): The text to classify
|
| 83 |
+
max_length (int): Ignored in this implementation as DETree handles it globally,
|
| 84 |
+
but kept for compatibility.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
dict: Contains predicted_class and confidence
|
| 88 |
+
"""
|
| 89 |
+
if detector is None:
|
| 90 |
return {"predicted_class": "Human", "confidence": -100.0}
|
| 91 |
+
|
|
|
|
|
|
|
|
|
|
| 92 |
try:
|
| 93 |
+
# detector.predict expects a list of strings
|
| 94 |
+
predictions = detector.predict([text])
|
| 95 |
+
if not predictions:
|
| 96 |
+
return {"predicted_class": "Human", "confidence": -100.0}
|
| 97 |
+
|
| 98 |
+
pred = predictions[0]
|
| 99 |
+
# pred.label is "Human" or "AI"
|
| 100 |
+
# Map to "Human" or "Ai" to match previous API
|
| 101 |
+
label = pred.label
|
| 102 |
+
if label == "AI":
|
| 103 |
+
label = "Ai"
|
| 104 |
+
|
| 105 |
+
# Confidence logic:
|
| 106 |
+
# If label is Human, use probability_human
|
| 107 |
+
# If label is Ai, use probability_ai
|
| 108 |
+
confidence = pred.probability_human if label == "Human" else pred.probability_ai
|
| 109 |
+
|
| 110 |
+
return {
|
| 111 |
+
"predicted_class": label,
|
| 112 |
+
"confidence": float(confidence)
|
| 113 |
+
}
|
| 114 |
except Exception as e:
|
| 115 |
print(f"Error during text prediction: {e}")
|
| 116 |
return {"predicted_class": "Human", "confidence": -100.0}
|
| 117 |
|
| 118 |
+
# ββ 4) Batch prediction ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 119 |
def predict_batch(texts, batch_size=16):
|
| 120 |
+
"""
|
| 121 |
+
Predict multiple texts in batches.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
texts (list): List of text strings to classify
|
| 125 |
+
batch_size (int): Batch size for processing
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
list: List of prediction dictionaries
|
| 129 |
+
"""
|
| 130 |
+
if detector is None:
|
| 131 |
return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
|
| 132 |
+
|
| 133 |
+
# Temporarily update batch size if needed, or just use the detector's default
|
| 134 |
+
# We'll update it to respect the argument
|
| 135 |
+
original_batch_size = detector.batch_size
|
| 136 |
+
detector.batch_size = batch_size
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
predictions = detector.predict(texts)
|
| 140 |
+
results = []
|
| 141 |
+
for text, pred in zip(texts, predictions):
|
| 142 |
+
label = pred.label
|
| 143 |
+
if label == "AI":
|
| 144 |
+
label = "Ai"
|
| 145 |
+
confidence = pred.probability_human if label == "Human" else pred.probability_ai
|
| 146 |
+
|
| 147 |
+
results.append({
|
| 148 |
+
"text": text,
|
| 149 |
+
"predicted_class": label,
|
| 150 |
+
"confidence": float(confidence)
|
| 151 |
+
})
|
| 152 |
+
return results
|
| 153 |
+
except Exception as e:
|
| 154 |
+
print(f"Error during batch prediction: {e}")
|
| 155 |
+
return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
|
| 156 |
+
finally:
|
| 157 |
+
detector.batch_size = original_batch_size
|