import os
import sys
from typing import Optional

import torch
from huggingface_hub import snapshot_download
# Ensure local detree package is importable
# This allows the script to find the 'detree' package if it sits in the same directory
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)
try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during initial setup check)
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None
# ── 1) Configuration ────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text" # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"
MAX_LEN = 512
MODEL_DIR = None
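# Note: snapshot_download caches files under the Hugging Face cache directory
# (typically ~/.cache/huggingface/hub, configurable via HF_HOME), so repeated
# runs reuse the local copy instead of re-downloading.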
try:
    # Download a local snapshot of just the Text folder and point MODEL_DIR at it
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")
# ── 2) Load model & tokenizer ───────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")
detector = None
try:
    if Detector and MODEL_DIR:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")
            print(f"Please ensure '{EMBEDDING_FILE}' is present in '{TEXT_SUBFOLDER}' of the Hugging Face repo.")
        # Initialize the DETree Detector.
        # This loads the model from MODEL_DIR and the embeddings from database_path.
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max",  # default pooling
        )
        print("Text classification model (DETree) loaded successfully")
    else:
        if not Detector:
            print("DETree detector could not be initialized due to missing package.")
        if not MODEL_DIR:
            print("DETree detector could not be initialized due to missing model directory.")
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")
# ── 3) Inference function ───────────────────────────────────────────────────
def predict_text(text: str, max_length: Optional[int] = None):
    """
    Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify.
        max_length (int, optional): Ignored in this implementation, since DETree
            handles truncation globally via MAX_LEN; kept for API compatibility.

    Returns:
        dict: Contains predicted_class and confidence.
    """
    if detector is None:
        return {"predicted_class": "Human", "confidence": -100.0}
    try:
        # detector.predict expects a list of strings
        predictions = detector.predict([text])
        if not predictions:
            return {"predicted_class": "Human", "confidence": -100.0}
        pred = predictions[0]
        # pred.label is "Human" or "AI"; map "AI" to "Ai" to match the previous API
        label = pred.label
        if label == "AI":
            label = "Ai"
        # Confidence: use the probability that matches the predicted label
        confidence = pred.probability_human if label == "Human" else pred.probability_ai
        return {
            "predicted_class": label,
            "confidence": float(confidence),
        }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {"predicted_class": "Human", "confidence": -100.0}
# ── 4) Batch prediction ─────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.

    Args:
        texts (list): List of text strings to classify.
        batch_size (int): Batch size for processing.

    Returns:
        list: List of prediction dictionaries.
    """
    if detector is None:
        return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
    # Temporarily override the detector's batch size so the argument is
    # respected, then restore the original value in the finally block.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    try:
        predictions = detector.predict(texts)
        results = []
        for text, pred in zip(texts, predictions):
            # Map "AI" to "Ai" to match the previous API
            label = pred.label
            if label == "AI":
                label = "Ai"
            confidence = pred.probability_human if label == "Human" else pred.probability_ai
            results.append({
                "text": text,
                "predicted_class": label,
                "confidence": float(confidence),
            })
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
    finally:
        detector.batch_size = original_batch_size
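# ── 5) Quick manual check ───────────────────────────────────────────────────
# A minimal sketch for manual testing; the sample strings below are
# placeholders, and the printed confidences depend entirely on the model
# snapshot that was downloaded.
if __name__ == "__main__":
    _samples = [
        "I walked to the market this morning and bought some fresh bread.",
        "As an AI language model, I can assist you with a wide range of tasks.",
    ]
    print(predict_text(_samples[0]))
    for _result in predict_batch(_samples, batch_size=2):
        print(_result)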