# NOTE(review): repository-upload residue ("Sky-Blue-da-ba-dee's picture",
# "added files", commit ac9ddbb) — commented out so the module parses.
"""Prediction helpers for different model types.
This module provides `ModelPredictor`, a lightweight wrapper that unifies
inference for SetFit, scikit-learn RandomForest pipelines, and HuggingFace
transformer sequence classification models. It standardizes inputs/outputs
to a NumPy array of shape (n_samples, n_labels).
"""
import os
from typing import List, Union
import joblib
import numpy as np
from setfit import SetFitModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
TextInput = Union[str, List[str]]
class ModelPredictor:
    """Unified predictor for SetFit, Random Forest and Transformer models.

    Expected directory layout::

        models/
        ├── java/
        │   ├── setfit/                 # SetFit saved model directory
        │   ├── random_forest.joblib    # sklearn pipeline
        │   └── transformer/            # HF model + tokenizer (config.json, etc.)
        ├── python/
        │   ├── setfit/
        │   ├── random_forest.joblib
        │   └── transformer/
        └── pharo/
            ├── setfit/
            ├── random_forest.joblib
            └── transformer/
    """

    def __init__(
        self,
        lang: str,
        model_type: str,
        model_root: str = "models",
        threshold: float = 0.5,
        max_length: int = 128,
    ) -> None:
        """Load the requested model from disk.

        Parameters
        ----------
        lang : str
            One of {"java", "python", "pharo"}.
        model_type : str
            One of {"setfit", "random_forest", "transformer"}.
        model_root : str
            Root directory where models are stored.
        threshold : float
            Decision threshold for multi-label Transformer predictions.
            Ignored for SetFit and Random Forest (they already output labels).
        max_length : int
            Max sequence length for Transformer tokenization.

        Raises
        ------
        FileNotFoundError
            If the expected model file/directory is missing.
        ValueError
            If ``model_type`` is not one of the supported values.
        """
        self.lang = lang
        self.model_type = model_type
        self.model_root = model_root
        self.threshold = float(threshold)
        self.max_length = int(max_length)
        # device only matters for the Transformer branch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model_type == "setfit":
            model_path = os.path.join(self.model_root, self.lang, "setfit")
            if not os.path.isdir(model_path):
                raise FileNotFoundError(f"SetFit model not found at: {model_path}")
            self.model = SetFitModel.from_pretrained(model_path)
        elif model_type == "random_forest":
            model_path = os.path.join(self.model_root, self.lang, "random_forest.joblib")
            if not os.path.isfile(model_path):
                raise FileNotFoundError(f"Random Forest model not found at: {model_path}")
            self.model = joblib.load(model_path)
        elif model_type == "transformer":
            model_path = os.path.join(self.model_root, self.lang, "transformer")
            if not os.path.isdir(model_path):
                raise FileNotFoundError(f"Transformer model not found at: {model_path}")
            # load tokenizer and model from the same directory used during training
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(
                self.device
            )
            self.model.eval()
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

    def predict(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Run prediction on one or many text samples.

        Parameters
        ----------
        texts : str | list[str]
            A single text or a list of texts.

        Returns
        -------
        np.ndarray
            Array of shape (n_samples, n_labels) with integer (typically
            binary) values.

        Raises
        ------
        ValueError
            If ``self.model_type`` is not one of the supported values.
        """
        if isinstance(texts, str):
            texts = [texts]

        if self.model_type == "setfit":
            raw_outputs = self.model(texts)
            outputs = np.array(list(raw_outputs), dtype=int)
        elif self.model_type == "random_forest":
            outputs = np.asarray(self.model.predict(texts), dtype=int)
        elif self.model_type == "transformer":
            enc = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}
            with torch.no_grad():
                logits = self.model(**enc).logits
            # multi-label head: independent sigmoid per label, thresholded
            probs = torch.sigmoid(logits)
            outputs = (probs > self.threshold).long().cpu().numpy().astype(int)
        else:
            raise ValueError(f"Unsupported model_type: {self.model_type}")

        # Single-label models return a flat array with one prediction per
        # sample; make it a column so the result is always (n_samples,
        # n_labels).  The previous reshape(1, -1) produced (1, n_samples)
        # whenever more than one text was passed, violating the contract.
        if outputs.ndim == 1:
            outputs = outputs.reshape(-1, 1)
        return outputs