File size: 5,195 Bytes
ac9ddbb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
"""Prediction helpers for different model types.
This module provides `ModelPredictor`, a lightweight wrapper that unifies
inference for SetFit, scikit-learn RandomForest pipelines, and HuggingFace
transformer sequence classification models. It standardizes inputs/outputs
to a NumPy array of shape (n_samples, n_labels).
"""
import os
from typing import List, Union
import joblib
import numpy as np
from setfit import SetFitModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Convenience alias: prediction input may be a single text or a batch of texts.
TextInput = Union[str, List[str]]
class ModelPredictor:
    """Unified predictor for SetFit, Random Forest and Transformer models.

    Standardizes inference across the three backends so callers always
    receive a NumPy array of shape (n_samples, n_labels).

    Expected directory layout::

        models/
        ├── java/
        │   ├── setfit/               # SetFit saved model directory
        │   ├── random_forest.joblib  # sklearn pipeline
        │   └── transformer/          # HF model + tokenizer (config.json, etc.)
        ├── python/
        │   ├── setfit/
        │   ├── random_forest.joblib
        │   └── transformer/
        └── pharo/
            ├── setfit/
            ├── random_forest.joblib
            └── transformer/
    """

    def __init__(
        self,
        lang: str,
        model_type: str,
        model_root: str = "models",
        threshold: float = 0.5,
        max_length: int = 128,
    ) -> None:
        """Load the requested model artifact from disk.

        Parameters
        ----------
        lang : str
            One of {"java", "python", "pharo"}.
        model_type : str
            One of {"setfit", "random_forest", "transformer"}.
        model_root : str
            Root directory where models are stored.
        threshold : float
            Decision threshold for multi-label Transformer predictions.
            Ignored for SetFit and Random Forest (they already output labels).
        max_length : int
            Max sequence length for Transformer tokenization.

        Raises
        ------
        FileNotFoundError
            If the artifact for (lang, model_type) does not exist on disk.
        ValueError
            If ``model_type`` is not one of the supported values.
        """
        self.lang = lang
        self.model_type = model_type
        self.model_root = model_root
        self.threshold = float(threshold)
        self.max_length = int(max_length)
        # device only matters for Transformer; cheap enough to set always
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model_type == "setfit":
            self._load_setfit()
        elif model_type == "random_forest":
            self._load_random_forest()
        elif model_type == "transformer":
            self._load_transformer()
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

    def _load_setfit(self) -> None:
        """Load a saved SetFit model directory into ``self.model``."""
        model_path = os.path.join(self.model_root, self.lang, "setfit")
        if not os.path.isdir(model_path):
            raise FileNotFoundError(f"SetFit model not found at: {model_path}")
        self.model = SetFitModel.from_pretrained(model_path)

    def _load_random_forest(self) -> None:
        """Load a joblib-pickled sklearn pipeline into ``self.model``."""
        model_path = os.path.join(self.model_root, self.lang, "random_forest.joblib")
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Random Forest model not found at: {model_path}")
        self.model = joblib.load(model_path)

    def _load_transformer(self) -> None:
        """Load an HF tokenizer + sequence-classification model pair."""
        model_path = os.path.join(self.model_root, self.lang, "transformer")
        if not os.path.isdir(model_path):
            raise FileNotFoundError(f"Transformer model not found at: {model_path}")
        # tokenizer and model must come from the same directory used during
        # training so the vocab and label mapping stay in sync
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(
            self.device
        )
        self.model.eval()

    def predict(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Run prediction on one or many text samples.

        Parameters
        ----------
        texts : str | list[str]
            A single text or a list of texts.

        Returns
        -------
        np.ndarray
            Array of shape (n_samples, n_labels) with integer (typically
            binary) values.

        Raises
        ------
        ValueError
            If ``self.model_type`` is not one of the supported values.
        """
        if isinstance(texts, str):
            texts = [texts]
        if self.model_type == "setfit":
            outputs = np.array(list(self.model(texts)), dtype=int)
        elif self.model_type == "random_forest":
            outputs = np.array(list(self.model.predict(texts)), dtype=int)
        elif self.model_type == "transformer":
            outputs = self._predict_transformer(texts)
        else:
            raise ValueError(f"Unsupported model_type: {self.model_type}")
        # Ensure 2D shape (n_samples, n_labels).
        # BUG FIX: the old reshape(1, -1) collapsed a 1-D result of n
        # single-label samples into one row of n columns; reshaping by the
        # actual sample count is correct both for the (n_samples,) single-label
        # case and for the (n_labels,) single-sample case.
        if outputs.ndim == 1:
            outputs = outputs.reshape(len(texts), -1)
        return outputs

    def _predict_transformer(self, texts: List[str]) -> np.ndarray:
        """Multi-label inference: sigmoid over logits, then threshold."""
        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        with torch.no_grad():
            logits = self.model(**enc).logits
        probs = torch.sigmoid(logits)
        # strictly greater-than, matching the original decision rule
        return (probs > self.threshold).long().cpu().numpy().astype(int)
|