import torch
import os
import sys
from pathlib import Path
from typing import Iterable, Optional

from huggingface_hub import snapshot_download

# Ensure the local detree package is importable.
# This allows the script to find the 'detree' package if it sits in the same directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during an initial setup check).
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None

# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"  # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"
MAX_LEN = 512

MODEL_DIR = None
try:
    # Download a local snapshot of just the Text folder and point MODEL_DIR at it.
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")

# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

# Stays None when the package, the download or the model load fails; the
# prediction functions below then return fallback results instead of raising.
detector = None
try:
    if Detector and MODEL_DIR:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")
            print(f"Please ensure '{EMBEDDING_FILE}' is present in '{TEXT_SUBFOLDER}' of the Hugging Face repo.")
        # Initialize the DETree Detector: model weights come from MODEL_DIR,
        # the embedding database from database_path.
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max",  # Default pooling
        )
        print("Text classification model (DETree) loaded successfully")
    else:
        if not Detector:
            print("DETree detector could not be initialized due to missing package.")
        if not MODEL_DIR:
            print("DETree detector could not be initialized due to missing model directory.")
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")


# ── 3) Inference helpers ───────────────────────────────────────────────────────
def _fallback_result() -> dict:
    """Return the sentinel result used whenever no real prediction is available.

    A confidence of -100.0 marks the result as a fallback rather than a model output.
    """
    return {"predicted_class": "Human", "confidence": -100.0}


def _format_prediction(pred) -> dict:
    """Convert one detector prediction into the public result dict.

    The label is whichever of ``pred.probability_ai`` / ``pred.probability_human``
    is larger; ties resolve to "Human". ``confidence`` is the probability of the
    chosen label. Labels are exactly "AI" or "Human".
    """
    if pred.probability_ai > pred.probability_human:
        return {"predicted_class": "AI", "confidence": float(pred.probability_ai)}
    return {"predicted_class": "Human", "confidence": float(pred.probability_human)}


# ── 4) Inference function ──────────────────────────────────────────────────────
def predict_text(text: str, max_length: Optional[int] = None) -> dict:
    """
    Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify.
        max_length (int, optional): Ignored. DETree uses the module-level MAX_LEN;
            the parameter is kept for API compatibility.

    Returns:
        dict: ``{"predicted_class": "AI" | "Human", "confidence": float}``.
        When the detector is unavailable, returns no prediction, or raises, the
        fallback ``{"predicted_class": "Human", "confidence": -100.0}`` is returned.
    """
    if detector is None:
        return _fallback_result()

    try:
        # detector.predict takes a list of strings.
        predictions = detector.predict([text])
        if not predictions:
            return _fallback_result()
        return _format_prediction(predictions[0])
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return _fallback_result()


# ── 5) Batch prediction ────────────────────────────────────────────────────────
def predict_batch(texts: Iterable[str], batch_size: int = 16) -> list:
    """
    Predict multiple texts in batches.

    Args:
        texts (iterable of str): Text strings to classify. Any iterable is accepted;
            it is materialized once, so generators work too.
        batch_size (int): Batch size used by the detector for this call only. The
            detector's previous batch size is restored afterwards.

    Returns:
        list: One dict per input text, in input order, each of the form
        ``{"text": str, "predicted_class": "AI" | "Human", "confidence": float}``.
        Texts without a usable prediction get the fallback ("Human", -100.0).
    """
    # Materialize once: the texts are iterated both by the detector and when
    # building results, which would silently exhaust a generator.
    texts = list(texts)

    if detector is None:
        return [{"text": text, **_fallback_result()} for text in texts]
    if not texts:
        return []

    # Temporarily override the detector's batch size to honour the argument.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    try:
        predictions = list(detector.predict(texts))
        results = []
        for index, text in enumerate(texts):
            if index < len(predictions):
                result = _format_prediction(predictions[index])
            else:
                # Detector returned fewer predictions than inputs: keep the text
                # in the output with a fallback instead of dropping it.
                result = _fallback_result()
            results.append({"text": text, **result})
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{"text": text, **_fallback_result()} for text in texts]
    finally:
        detector.batch_size = original_batch_size