# Source: AI_API/features/text_classifier/inferencer.py
# Author: Pujan-Dev
# Commit 31fda96 -- "update: updated the config and text_classifier"
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
import logging
import random
from typing import Any
import nltk
import numpy as np
from scipy.sparse import csr_matrix, hstack
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from features.text_classifier.model_loader import load_model
logger = logging.getLogger(__name__)

# Make sure the NLTK tokenizer data is available at import time. Each entry
# is an NLTK resource path; when the lookup fails, the package named by the
# last path segment is downloaded quietly.
for resource in ("tokenizers/punkt", "tokenizers/punkt_tab"):
    try:
        nltk.data.find(resource)
    except LookupError:
        nltk.download(resource.rpartition("/")[2], quiet=True)

# textstat is optional: when it cannot be imported the name is bound to None
# so later code can test for its presence.
try:
    import textstat
except ImportError:
    textstat = None
@dataclass
class SentenceBlendConfig:
    """Bundle of numeric tuning parameters, each with a default value.

    NOTE(review): nothing in the visible part of this module reads these
    fields, so their meaning is inferred from the names only (they appear to
    govern how sentence-level and document-level scores are blended) --
    confirm against the consumer before relying on it.
    """

    sentence_blend_weight: float = 0.7
    sentence_to_doc_bias: float = 0.35
    # The two ``max_*`` fields look like upper bounds for the fields above
    # -- TODO confirm.
    max_sentence_blend_weight: float = 0.9
    max_sentence_to_doc_bias: float = 0.8
    random_deviation_pct: float = 2.0
class PerplexityCalculator:
    """Lazy-loaded perplexity calculator backed by a causal language model.

    The tokenizer and model (``distilgpt2`` by default) are loaded on the
    first call to :meth:`calculate`, not when the object is constructed.
    """

    # Value returned whenever a usable perplexity cannot be computed.
    FALLBACK_PERPLEXITY: float = 100.0
    # Upper bound applied to every computed perplexity.
    MAX_PERPLEXITY: float = 10000.0

    def __init__(self, model_name: str = "distilgpt2"):
        """Record the model name and pick CUDA when available, else CPU."""
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = None
        self._model = None

    def _load(self) -> None:
        """Load tokenizer and model on first use; do nothing once both exist."""
        if self._model is not None and self._tokenizer is not None:
            return
        logger.info("Loading perplexity model: %s", self.model_name)
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self._model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)
        self._model.eval()
        logger.info("Perplexity model loaded on %s", self.device)

    def calculate(self, text: str, max_length: int = 512) -> float:
        """Return the perplexity of *text* under the language model.

        Args:
            text: Text to score.
            max_length: Token limit passed to the tokenizer; truncation is
                enabled, so longer input is cut to this length.

        Returns:
            ``exp(loss)`` capped at ``MAX_PERPLEXITY``. When the loss is NaN
            or any exception is raised while loading or scoring,
            ``FALLBACK_PERPLEXITY`` is returned instead.
        """
        try:
            self._load()
            encodings = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            input_ids = encodings.input_ids.to(self.device)
            with torch.no_grad():
                loss = self._model(input_ids, labels=input_ids).loss
            if bool(torch.isnan(loss)):
                # A NaN loss would slip through the cap below because
                # min(nan, cap) evaluates to nan. This can happen, for
                # example, when the input is so short that no next-token
                # targets remain. Use the same fallback as the error path so
                # callers never receive NaN.
                return self.FALLBACK_PERPLEXITY
            perplexity = float(torch.exp(loss).item())
            return min(perplexity, self.MAX_PERPLEXITY)
        except Exception as exc:
            # Deliberate best-effort: scoring problems must not break the
            # caller, so report them and return the neutral fallback.
            logger.warning("Perplexity fallback used due to error: %s", exc)
            return self.FALLBACK_PERPLEXITY
# Single shared calculator instance used by the cached helper below.
_perplexity_calc = PerplexityCalculator()


@lru_cache(maxsize=20000)
def _cached_perplexity(cleaned_text: str) -> float:
    """Return the perplexity of *cleaned_text*, memoised per distinct string.

    Up to 20000 results are kept by the LRU cache; the computation itself is
    delegated to the module-level calculator.
    """
    score = _perplexity_calc.calculate(cleaned_text)
    return score
@lru_cache(maxsize=1)
def _get_model_artifacts() -> tuple[Any, Any, Any, Any, list[str], dict[str, Any]]:
    """Return whatever ``load_model()`` produces, calling it only once.

    The single-entry cache means later calls reuse the first result. The
    annotation declares a 6-tuple whose last two items are a list of strings
    and a dict.
    """
    artifacts = load_model()
    return artifacts
def normalize_text(text: str) -> str:
    """Collapse every run of whitespace in *text* to a single space.

    The input is passed through ``str()`` first, so non-string values are
    accepted. Leading and trailing whitespace disappears because splitting on
    whitespace never yields empty edge tokens.
    """
    tokens = str(text).split()
    return " ".join(tokens)
def split_into_sentences(text: str) -> list[str]:
    """Normalise *text* and split it into sentences with NLTK.

    Returns an empty list when the normalised text is empty. If the
    tokenizer yields no non-blank sentence, the whole normalised text is
    returned as the only item.
    """
    cleaned = normalize_text(text)
    if not cleaned:
        return []

    pieces: list[str] = []
    for candidate in nltk.sent_tokenize(cleaned):
        trimmed = candidate.strip()
        if trimmed:
            pieces.append(trimmed)
    return pieces or [cleaned]
def extract_burstiness_features(text: str) -> dict[str, float]:
    """Summarise how sentence lengths (in whitespace-separated words) vary.

    Returns the mean, standard deviation, maximum, minimum and range of the
    per-sentence word counts under ``burst_*`` keys. Every value is 0.0 when
    the text contains no sentences.
    """
    keys = ("burst_mean", "burst_std", "burst_max", "burst_min", "burst_range")

    sentences = split_into_sentences(text)
    if not sentences:
        return dict.fromkeys(keys, 0.0)

    word_counts = np.array([len(sentence.split()) for sentence in sentences], dtype=float)
    longest = np.max(word_counts)
    shortest = np.min(word_counts)
    stats = (
        float(np.mean(word_counts)),
        float(np.std(word_counts)),
        float(longest),
        float(shortest),
        float(longest - shortest),
    )
    return dict(zip(keys, stats))
def extract_stylometry_features(text: str) -> dict[str, float]:
    """Compute simple surface-level style statistics for *text*.

    Covers word, character and sentence counts, average word and sentence
    length, lexical diversity, punctuation and upper-case ratios, and two
    Flesch readability scores. The readability scores fall back to 50.0 and
    8.0 when ``textstat`` is unavailable or raises.
    """

    def _ratio(count: int, total: int) -> float:
        # Guarded division: an empty denominator yields 0.0.
        return float(count / total) if total > 0 else 0.0

    words = text.split()
    num_words = len(words)
    num_chars = len(text)
    # At least one sentence is assumed so the average below never divides by 0.
    num_sentences = max(len(split_into_sentences(text)), 1)

    avg_word_len = float(np.mean([len(word) for word in words])) if words else 0.0
    avg_sent_len = float(num_words / num_sentences)
    lexical_diversity = _ratio(len(set(words)), num_words)
    punct_ratio = _ratio(sum(1 for ch in text if ch in ".,!?;:"), num_chars)
    caps_ratio = _ratio(sum(1 for ch in text if ch.isupper()), num_chars)

    readability_defaults = (50.0, 8.0)
    if textstat is None:
        flesch_reading, flesch_grade = readability_defaults
    else:
        try:
            # Both scores are evaluated before either name is bound, so a
            # failure in the second call leaves both at their defaults.
            flesch_reading, flesch_grade = (
                float(textstat.flesch_reading_ease(text)),
                float(textstat.flesch_kincaid_grade(text)),
            )
        except Exception:
            flesch_reading, flesch_grade = readability_defaults

    return {
        "num_words": float(num_words),
        "num_chars": float(num_chars),
        "num_sentences": float(num_sentences),
        "avg_word_len": avg_word_len,
        "avg_sent_len": avg_sent_len,
        "lexical_diversity": lexical_diversity,
        "punct_ratio": punct_ratio,
        "caps_ratio": caps_ratio,
        "flesch_reading": flesch_reading,
        "flesch_grade": flesch_grade,
    }
def extract_all_features(text: str, calc_perplexity: bool = True) -> dict[str, float]:
    """Build the full engineered-feature dictionary for *text*.

    The text is normalised first. ``perplexity`` comes from the cached
    calculator when *calc_perplexity* is true and is the constant 100.0
    otherwise. Burstiness and stylometry features are then merged in, in
    that order.
    """
    cleaned = normalize_text(text)
    perplexity = _cached_perplexity(cleaned) if calc_perplexity else 100.0
    return {
        "perplexity": perplexity,
        **extract_burstiness_features(cleaned),
        **extract_stylometry_features(cleaned),
    }
def _predict_ai_probability(text: str) -> tuple[float, float]:
    """Score *text* with the loaded classifier.

    Returns:
        ``(ai_prob, perplexity)``. ``ai_prob`` is the second entry of
        ``predict_proba`` when the classifier offers it, otherwise the
        logistic sigmoid of ``decision_function``. ``perplexity`` is the
        value stored in the extracted features.
    """
    classifier, scaler, word_vectorizer, char_vectorizer, feature_names, metadata = (
        _get_model_artifacts()
    )

    # Real perplexity is requested only when the metadata reports a positive
    # number of engineered features.
    wants_perplexity = bool(metadata.get("num_engineered_features", 0) > 0)
    features = extract_all_features(text, calc_perplexity=wants_perplexity)

    # Arrange the numeric features in the order given by the loaded name list.
    ordered_values = [features[name] for name in feature_names]
    scaled = scaler.transform(np.array(ordered_values, dtype=float).reshape(1, -1))

    # Stack word n-grams, char n-grams and scaled numeric features side by side.
    blocks = [
        word_vectorizer.transform([text]),
        char_vectorizer.transform([text]),
        csr_matrix(scaled),
    ]
    hybrid = hstack(blocks, format="csr")

    if hasattr(classifier, "predict_proba"):
        ai_prob = float(classifier.predict_proba(hybrid)[0][1])
    else:
        margin = float(classifier.decision_function(hybrid)[0])
        ai_prob = float(1.0 / (1.0 + np.exp(-margin)))

    return ai_prob, float(features.get("perplexity", 100.0))
def classify_text(text: str) -> tuple[str, float, float]:
    """Return (label, perplexity, ai_likelihood_percent).

    The label is ``"AI"`` when the likelihood, rounded to two decimals, is at
    least 50.0, and ``"Human"`` otherwise.

    Raises:
        ValueError: If *text* is empty after whitespace normalisation.
    """
    cleaned = normalize_text(text)
    if not cleaned:
        raise ValueError("Input text is empty")

    probability, perplexity = _predict_ai_probability(cleaned)
    ai_likelihood = round(probability * 100.0, 2)
    if ai_likelihood >= 50.0:
        label = "AI"
    else:
        label = "Human"
    return label, perplexity, ai_likelihood
def _analyze_sentence(sentence: str) -> dict[str, Any]:
    """Classify one sentence, reporting a failure as an ``"Error"`` entry."""
    try:
        label, perplexity, ai_likelihood = classify_text(sentence)
    except Exception as exc:
        # Best-effort: one bad sentence must not abort the whole analysis.
        logger.warning("Error analyzing sentence: %s", exc)
        label, perplexity, ai_likelihood = "Error", None, None
    return {
        "sentence": sentence,
        "label": label,
        "perplexity": perplexity,
        "ai_likelihood": ai_likelihood,
    }


def analyze_text_with_sentences(
    text: str,
) -> dict[str, Any]:
    """Classify *text* as a whole and sentence by sentence.

    Returns:
        A dict with ``"sentences"`` (one result dict per sentence) and
        ``"summary"``, whose ``"overall"`` entry holds the label, perplexity
        and AI likelihood of the full text.

    Raises:
        ValueError: If the text is empty or yields no sentences.
    """
    text = normalize_text(text)
    # The whole document is scored first; an empty text raises here.
    overall_label, overall_perplexity, overall_ai_likelihood = classify_text(text)

    sentences = split_into_sentences(text)
    if not sentences:
        raise ValueError("Input text contains no valid sentences")

    overall = {
        "label": overall_label,
        "perplexity": overall_perplexity,
        "ai_likelihood": overall_ai_likelihood,
    }
    return {
        "sentences": [_analyze_sentence(sentence) for sentence in sentences],
        "summary": {"overall": overall},
    }