File size: 4,718 Bytes
77a8ece
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57b4170
77a8ece
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57b4170
 
 
77a8ece
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import pickle

import joblib
import numpy as np
import torch
from transformers import PreTrainedModel

from .configuration_sm_subgroup_classifier import SmSubgroupClassifierConfig


class SmSubgroupClassifier(PreTrainedModel):
    """Dispatch pre-computed embeddings to per-(language, model) classifiers.

    The class registers no torch weights of its own. Each sub-model lives in
    its own sub-directory of ``model_dir`` named ``"{language}_{model_name}"``
    containing ``model.pkl`` (loaded with ``joblib``) and ``metadata.pkl``
    (a pickled dict-like object that must provide ``"class_names"``).
    Sub-models are loaded lazily on first use and cached on the instance.
    """

    config_class = SmSubgroupClassifierConfig

    # Keyword arguments of ``from_pretrained`` that are also forwarded to
    # ``huggingface_hub.snapshot_download`` when resolving a Hub repo id.
    _HUB_DOWNLOAD_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "local_files_only",
        "force_download",
    )
    # Files a sub-directory must contain to be treated as a sub-model.
    _REQUIRED_FILES = ("model.pkl", "metadata.pkl")

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Loaded sub-models keyed by their directory on disk, so changing
        # ``model_dir`` never returns a classifier from the old location.
        self._loaded_classifiers = {}
        # Root directory holding one sub-directory per sub-model; set by
        # ``from_pretrained`` or assigned directly by the caller.
        self.model_dir = None
        # Cached scan result and the ``model_dir`` it was computed for.
        self._available_models = None
        self._available_models_dir = None

    @property
    def available_models(self):
        """Sorted names of the valid sub-model directories under ``model_dir``.

        The directory scan is cached and redone whenever ``model_dir`` has
        changed since the previous scan (e.g. it was set after first access).
        """
        if (
            self._available_models is None
            or self._available_models_dir != self.model_dir
        ):
            self._available_models = self._discover_available_models()
            self._available_models_dir = self.model_dir
        return self._available_models

    def _discover_available_models(self):
        """Scan ``model_dir`` for sub-directories holding all required files.

        Returns:
            Sorted list of sub-directory names; empty when ``model_dir`` is
            unset or is not an existing directory.
        """
        # isdir (not exists) so a model_dir pointing at a file yields [] instead
        # of crashing os.listdir.
        if not self.model_dir or not os.path.isdir(self.model_dir):
            return []

        models = []
        for entry in os.listdir(self.model_dir):
            entry_path = os.path.join(self.model_dir, entry)
            if os.path.isdir(entry_path) and all(
                os.path.exists(os.path.join(entry_path, name))
                for name in self._REQUIRED_FILES
            ):
                models.append(entry)
        return sorted(models)

    def _load_classifier(self, model_key):
        """Load, or fetch from the cache, the sub-model named *model_key*.

        Args:
            model_key: Sub-directory name, e.g. ``"en_OP-ob"``.

        Returns:
            dict with ``"classifier"`` (the joblib-loaded estimator) and
            ``"class_names"`` (taken from ``metadata.pkl``).

        Raises:
            ValueError: if *model_key* is not in ``available_models``.
            FileNotFoundError: if the sub-directory is missing on disk.
        """
        # Reading the property also refreshes the scan if model_dir changed.
        available = self.available_models
        if model_key not in available:
            raise ValueError(
                f"Model '{model_key}' not available. Available: {available}"
            )

        classifier_path = os.path.join(self.model_dir, model_key)
        cached = self._loaded_classifiers.get(classifier_path)
        if cached is not None:
            return cached

        if not os.path.exists(classifier_path):
            raise FileNotFoundError(f"Classifier not found at {classifier_path}")

        # SECURITY: joblib.load and pickle.load can execute arbitrary code
        # embedded in the files. Only load model directories from trusted
        # sources (this includes files downloaded from the Hub).
        classifier = joblib.load(os.path.join(classifier_path, "model.pkl"))
        with open(os.path.join(classifier_path, "metadata.pkl"), "rb") as f:
            metadata = pickle.load(f)

        classifier_info = {
            "classifier": classifier,
            "class_names": metadata["class_names"],
        }
        self._loaded_classifiers[classifier_path] = classifier_info
        return classifier_info

    def forward(self, language, model_name, embeddings):
        """Classify embeddings with the ``{language}_{model_name}`` sub-model.

        Args:
            language: Language code (en, fi, sv).
            model_name: Model name (OP-ob, NA, etc.).
            embeddings: Pre-computed embeddings as a torch tensor, an
                array-like with ``ndim`` (passed through unchanged), or a
                nested Python sequence. A 1-D input is treated as one sample.

        Returns:
            ``{"label": ..., "probabilities": {class_name: float, ...}}`` when
            exactly one sample is given, otherwise a list of such dicts
            (an empty list for zero samples).

        Raises:
            ValueError: if the requested sub-model is not available.
        """
        model_key = f"{language}_{model_name}"

        # The sub-models are not torch modules, so hand them host numpy data.
        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().cpu().numpy()
        elif not hasattr(embeddings, "ndim"):
            # Lists/tuples: convert. Arrays (dense or sparse) keep their type.
            embeddings = np.asarray(embeddings)
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        classifier_info = self._load_classifier(model_key)
        classifier = classifier_info["classifier"]
        class_names = classifier_info["class_names"]

        # Predict directly without scaling.
        predictions = classifier.predict(embeddings)
        probabilities = classifier.predict_proba(embeddings)

        # NOTE(review): class_names is indexed by the raw predicted label and by
        # the probability column position. This assumes the estimator was
        # trained on integer labels 0..n-1 in class_names order — confirm.
        results = [
            {
                "label": class_names[pred],
                "probabilities": {
                    class_names[i]: float(prob) for i, prob in enumerate(probs)
                },
            }
            for pred, probs in zip(predictions, probabilities)
        ]
        return results[0] if len(results) == 1 else results

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the model from a local directory or a Hugging Face Hub repo.

        Only the config is loaded through ``transformers``; no torch weights
        are loaded. ``model_dir`` is set to the local directory itself, or to
        the downloaded/cached Hub snapshot directory.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            **kwargs: Passed to the config loader. ``revision``, ``cache_dir``,
                ``token``, ``local_files_only`` and ``force_download`` are also
                forwarded to ``snapshot_download``.

        Returns:
            A new ``SmSubgroupClassifier`` with ``model_dir`` set.
        """
        config = SmSubgroupClassifierConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        # Create model instance (skip the pytorch weight loading).
        model = cls(config)
        model.model_dir = cls._resolve_model_dir(
            pretrained_model_name_or_path, kwargs
        )
        return model

    @classmethod
    def _resolve_model_dir(cls, pretrained_model_name_or_path, kwargs):
        """Return a local directory that contains the sub-model folders."""
        # Use a local checkout as-is. snapshot_download would treat the path as
        # a repo id and fail.
        if os.path.isdir(pretrained_model_name_or_path):
            return pretrained_model_name_or_path
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            # Fallback if huggingface_hub is not installed.
            return pretrained_model_name_or_path
        hub_kwargs = {
            key: kwargs[key]
            for key in cls._HUB_DOWNLOAD_KWARGS
            if kwargs.get(key) is not None
        }
        return snapshot_download(pretrained_model_name_or_path, **hub_kwargs)