|
|
"""Prediction helpers for different model types. |
|
|
|
|
|
This module provides `ModelPredictor`, a lightweight wrapper that unifies |
|
|
inference for SetFit, scikit-learn RandomForest pipelines, and HuggingFace |
|
|
transformer sequence classification models. It standardizes inputs/outputs |
|
|
to a NumPy array of shape (n_samples, n_labels). |
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import List, Union |
|
|
|
|
|
import joblib |
|
|
import numpy as np |
|
|
from setfit import SetFitModel |
|
|
import torch |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
|
# Accepted prediction input: a single text, or a batch of texts.
TextInput = Union[str, List[str]]
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Unified predictor for SetFit, Random Forest and Transformer models.

    Expected directory layout::

        models/
        ├── java/
        │   ├── setfit/                 # SetFit saved model directory
        │   ├── random_forest.joblib    # sklearn pipeline
        │   └── transformer/            # HF model + tokenizer (config.json, etc.)
        ├── python/
        │   ├── setfit/
        │   ├── random_forest.joblib
        │   └── transformer/
        └── pharo/
            ├── setfit/
            ├── random_forest.joblib
            └── transformer/
    """

    def __init__(
        self,
        lang: str,
        model_type: str,
        model_root: str = "models",
        threshold: float = 0.5,
        max_length: int = 128,
    ) -> None:
        """Load the requested model from disk.

        Parameters
        ----------
        lang : str
            One of {"java", "python", "pharo"}.
        model_type : str
            One of {"setfit", "random_forest", "transformer"}.
        model_root : str
            Root directory where models are stored.
        threshold : float
            Decision threshold for multi-label Transformer predictions.
            Ignored for SetFit and Random Forest (they already output labels).
        max_length : int
            Max sequence length for Transformer tokenization.

        Raises
        ------
        FileNotFoundError
            If the expected model file/directory for ``lang`` does not exist.
        ValueError
            If ``model_type`` is not one of the supported values.
        """
        self.lang = lang
        self.model_type = model_type
        self.model_root = model_root
        self.threshold = float(threshold)
        self.max_length = int(max_length)

        # GPU if available; only the Transformer path actually moves tensors
        # to this device, but it is computed once up front.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model_type == "setfit":
            model_path = os.path.join(self.model_root, self.lang, "setfit")
            if not os.path.isdir(model_path):
                raise FileNotFoundError(f"SetFit model not found at: {model_path}")
            self.model = SetFitModel.from_pretrained(model_path)

        elif model_type == "random_forest":
            model_path = os.path.join(self.model_root, self.lang, "random_forest.joblib")
            if not os.path.isfile(model_path):
                raise FileNotFoundError(f"Random Forest model not found at: {model_path}")
            self.model = joblib.load(model_path)

        elif model_type == "transformer":
            model_path = os.path.join(self.model_root, self.lang, "transformer")
            if not os.path.isdir(model_path):
                raise FileNotFoundError(f"Transformer model not found at: {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(
                self.device
            )
            # Inference mode: disables dropout / batch-norm updates.
            self.model.eval()

        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

    def predict(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Run prediction on one or many text samples.

        Parameters
        ----------
        texts : str | list[str]
            A single text or a list of texts. An empty list yields an
            empty ``(0, 0)`` array without invoking the model.

        Returns
        -------
        np.ndarray
            Array of shape (n_samples, n_labels) with integer (typically
            binary) values.

        Raises
        ------
        ValueError
            If ``self.model_type`` is not a supported value.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            # Avoid calling the backends with an empty batch — both the HF
            # tokenizer and sklearn raise on empty input.
            return np.empty((0, 0), dtype=int)

        if self.model_type == "setfit":
            raw_outputs = self.model(texts)
            outputs = np.array(list(raw_outputs), dtype=int)

        elif self.model_type == "random_forest":
            outputs = np.array(list(self.model.predict(texts)), dtype=int)

        elif self.model_type == "transformer":
            outputs = self._predict_transformer(texts)

        else:
            raise ValueError(f"Unsupported model_type: {self.model_type}")

        # Normalize 1-D outputs to (n_samples, n_labels). Reshaping by
        # len(texts) — rather than a blanket reshape(1, -1) — keeps
        # single-label predictions for n samples as (n, 1) instead of
        # incorrectly collapsing them into one row of n "labels".
        if outputs.ndim == 1:
            outputs = outputs.reshape(len(texts), -1)

        return outputs

    def _predict_transformer(self, texts: List[str]) -> np.ndarray:
        """Tokenize, run the HF model, and threshold sigmoid probabilities.

        Returns an integer array of shape (n_samples, n_labels).
        """
        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}

        with torch.no_grad():
            logits = self.model(**enc).logits

        # Multi-label setup: independent sigmoid per label, then threshold.
        probs = torch.sigmoid(logits)
        return (probs > self.threshold).long().cpu().numpy()
|
|
|