File size: 4,718 Bytes
77a8ece 57b4170 77a8ece 57b4170 77a8ece |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import pickle
import joblib
import torch
from transformers import PreTrainedModel
from .configuration_sm_subgroup_classifier import SmSubgroupClassifierConfig
class SmSubgroupClassifier(PreTrainedModel):
    """Dispatch pre-computed embeddings to pickled per-model classifiers.

    The model holds no PyTorch weights of its own. Each classifier lives in a
    subdirectory of ``model_dir`` (named by its model key, e.g. ``en_OP-ob``)
    containing ``model.pkl`` (loaded with ``joblib``) and ``metadata.pkl``
    (loaded with ``pickle``; must provide a ``"class_names"`` entry).
    Classifiers are loaded lazily on first use and cached per instance.
    """

    config_class = SmSubgroupClassifierConfig

    # Files a subdirectory must contain to be treated as a model.
    _REQUIRED_FILES = ("model.pkl", "metadata.pkl")

    # ``from_pretrained`` keyword arguments that are also forwarded to
    # ``snapshot_download`` so config and classifier files come from the same
    # revision / cache.
    _HUB_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "local_files_only",
        "force_download",
    )

    def __init__(self, config):
        """Create an empty wrapper; ``model_dir`` must be set before use.

        Args:
            config: Configuration object passed through to ``PreTrainedModel``.
        """
        super().__init__(config)
        self.config = config
        # model key -> {"classifier": ..., "class_names": ...}
        self._loaded_classifiers = {}
        # Directory holding one subdirectory per classifier.
        self.model_dir = None
        # Cached result of the directory scan (None = not scanned yet).
        self._available_models = None

    @property
    def available_models(self):
        """Sorted list of model keys found under ``model_dir``.

        The scan result is cached once ``model_dir`` has been assigned. While
        ``model_dir`` is still unset an empty list is returned *without*
        caching, so models become visible as soon as the directory is set.
        """
        if self._available_models is None:
            discovered = self._discover_available_models()
            if not self.model_dir:
                return discovered
            self._available_models = discovered
        return self._available_models

    def _discover_available_models(self):
        """Scan ``model_dir`` for subdirectories that contain all required files.

        Returns:
            Sorted list of subdirectory names; empty if ``model_dir`` is unset
            or is not an existing directory.
        """
        if not self.model_dir or not os.path.isdir(self.model_dir):
            return []

        models = []
        for entry in os.listdir(self.model_dir):
            entry_path = os.path.join(self.model_dir, entry)
            if not os.path.isdir(entry_path):
                continue
            if all(
                os.path.exists(os.path.join(entry_path, name))
                for name in self._REQUIRED_FILES
            ):
                models.append(entry)
        return sorted(models)

    def _load_classifier(self, model_key):
        """Load (or fetch from cache) the classifier for ``model_key``.

        Args:
            model_key: Subdirectory name of the classifier (e.g. 'en_OP-ob').

        Returns:
            Dict with keys ``"classifier"`` and ``"class_names"``.

        Raises:
            ValueError: If ``model_key`` is not among ``available_models``.
            FileNotFoundError: If the classifier directory has disappeared.
            KeyError: If the metadata has no ``"class_names"`` entry.
        """
        cached = self._loaded_classifiers.get(model_key)
        if cached is not None:
            return cached

        if model_key not in self.available_models:
            raise ValueError(
                f"Model '{model_key}' not available. Available: {self.available_models}"
            )

        classifier_path = os.path.join(self.model_dir, model_key)
        if not os.path.exists(classifier_path):
            raise FileNotFoundError(f"Classifier not found at {classifier_path}")

        # SECURITY NOTE: joblib.load and pickle.load execute arbitrary code
        # embedded in the files. Only point ``model_dir`` at trusted sources.
        classifier = joblib.load(os.path.join(classifier_path, "model.pkl"))
        with open(os.path.join(classifier_path, "metadata.pkl"), "rb") as f:
            metadata = pickle.load(f)

        classifier_info = {
            "classifier": classifier,
            "class_names": metadata["class_names"],
        }
        self._loaded_classifiers[model_key] = classifier_info
        return classifier_info

    @staticmethod
    def _probability_column_names(classifier, class_names, n_columns):
        """Return the class name for each ``predict_proba`` column.

        Columns are resolved through ``classifier.classes_`` when available so
        they use the same label -> name lookup as the predicted label. If that
        attribute is missing, has a different length, or its labels cannot be
        looked up in ``class_names``, positional lookup is used instead.
        """
        labels = getattr(classifier, "classes_", None)
        if labels is not None and len(labels) == n_columns:
            try:
                return [class_names[label] for label in labels]
            except (IndexError, KeyError, TypeError):
                # Labels are not valid keys into class_names; keep the
                # positional behaviour below.
                pass
        return [class_names[i] for i in range(n_columns)]

    def forward(self, language, model_name, embeddings):
        """Classify pre-computed embeddings with the selected classifier.

        Args:
            language: Language code (en, fi, sv)
            model_name: Model name (OP-ob, NA, etc.)
            embeddings: Pre-computed embeddings as a torch tensor, an
                array-like exposing ``ndim``, or a plain (nested) sequence.
                A 1-D input is treated as a single sample.

        Returns:
            For exactly one sample, a dict ``{"label": ..., "probabilities":
            {class_name: float, ...}}``; otherwise a list of such dicts (empty
            for an empty batch).

        Raises:
            ValueError: If ``f"{language}_{model_name}"`` is not available.
        """
        # Local import: numpy is not imported at module level in this file.
        import numpy as np

        model_key = f"{language}_{model_name}"

        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().cpu().numpy()
        elif not hasattr(embeddings, "ndim"):
            # Plain Python sequences have no ``ndim``; convert them.
            embeddings = np.asarray(embeddings)
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        classifier_info = self._load_classifier(model_key)
        classifier = classifier_info["classifier"]
        class_names = classifier_info["class_names"]

        # Embeddings are passed to the classifier as-is (no scaling step).
        predictions = classifier.predict(embeddings)
        probabilities = classifier.predict_proba(embeddings)

        results = []
        column_names = None
        for pred, probs in zip(predictions, probabilities):
            if column_names is None:
                # Resolved once; identical for every row of the batch.
                column_names = self._probability_column_names(
                    classifier, class_names, len(probs)
                )
            results.append(
                {
                    "label": class_names[pred],
                    "probabilities": {
                        name: float(prob) for name, prob in zip(column_names, probs)
                    },
                }
            )

        # A single sample is returned unwrapped; callers rely on this shape.
        return results[0] if len(results) == 1 else results

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the model from a local directory or a Hugging Face Hub repo.

        No PyTorch weights are loaded; only the config is read and
        ``model_dir`` is resolved so classifiers can be loaded lazily.

        Args:
            pretrained_model_name_or_path: Local directory containing the
                classifier subdirectories, or a Hub repo id.
            **kwargs: Passed to the config loader. The subset listed in
                ``_HUB_KWARGS`` is also passed to ``snapshot_download``.

        Returns:
            A ``SmSubgroupClassifier`` with ``model_dir`` set.
        """
        config = SmSubgroupClassifierConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        # Instantiate directly, skipping the PyTorch weight loading.
        model = cls(config)

        # A local directory is used as-is; it is not a Hub repo id and must
        # not be sent to snapshot_download.
        if os.path.isdir(pretrained_model_name_or_path):
            model.model_dir = pretrained_model_name_or_path
            return model

        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            # Fallback if huggingface_hub is not available.
            model.model_dir = pretrained_model_name_or_path
            return model

        hub_kwargs = {key: kwargs[key] for key in cls._HUB_KWARGS if key in kwargs}
        # Download (or reuse) the cached snapshot and point at its directory.
        model.model_dir = snapshot_download(
            pretrained_model_name_or_path, **hub_kwargs
        )
        return model
|