File size: 2,291 Bytes
28caba5
198ccb0
28caba5
 
 
 
 
 
 
 
198ccb0
 
 
 
 
 
 
 
 
 
 
 
 
 
28caba5
198ccb0
 
 
 
28caba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
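"""Model architectures for news classification.

Important: keep this package lightweight at import time.
Render/Uvicorn must import `api.main:app` before binding the port; importing the
training stack (pytorch-lightning/torchmetrics/matplotlib) here can delay startup
and cause Render port-scan timeouts. For that reason the public names below are
not imported eagerly: the module-level ``__getattr__`` imports the backing
submodule the first time a name is accessed.
"""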

from importlib import import_module
from typing import Any

# Public API of the package. Listing a name here imports nothing; each one is
# resolved on first attribute access by the module-level ``__getattr__``.
# NOTE: ``from models import *`` looks up *every* entry below, so a star-import
# also imports the training-only modules and defeats the lazy loading.
__all__ = [
    # inference / serving models
    "SimpleClassifier",
    "CNNClassifier",
    "RussianNewsClassifier",
    "MultilingualBERTClassifier",
    "RoBERTaNewsClassifier",
    "DistilBERTNewsClassifier",
    "MultiHeadAttentionClassifier",
    "EnsembleClassifier",
    "WeightedEnsemble",
    "StackingEnsemble",
    "VotingEnsemble",
    "create_ensemble",
] + [
    # training-only (kept for convenience, but lazily imported)
    "NewsClassificationModule",
    "TransformerClassificationModule",
]

# Maps each public name to ``(module path, attribute name)``. ``__getattr__``
# consults this table so the heavy submodules are imported only on first use.
_LAZY: dict[str, tuple[str, str]] = {
    "SimpleClassifier": ("models.simple_classifier", "SimpleClassifier"),
    "CNNClassifier": ("models.cnn_classifier", "CNNClassifier"),
    "RussianNewsClassifier": ("models.transformer_model", "RussianNewsClassifier"),
    "MultilingualBERTClassifier": ("models.transformer_model", "MultilingualBERTClassifier"),
    "RoBERTaNewsClassifier": ("models.advanced_transformers", "RoBERTaNewsClassifier"),
    "DistilBERTNewsClassifier": ("models.advanced_transformers", "DistilBERTNewsClassifier"),
    "MultiHeadAttentionClassifier": ("models.advanced_transformers", "MultiHeadAttentionClassifier"),
    "EnsembleClassifier": ("models.advanced_transformers", "EnsembleClassifier"),
    "WeightedEnsemble": ("models.ensemble", "WeightedEnsemble"),
    "StackingEnsemble": ("models.ensemble", "StackingEnsemble"),
    "VotingEnsemble": ("models.ensemble", "VotingEnsemble"),
    "create_ensemble": ("models.ensemble", "create_ensemble"),
    # training-only
    "NewsClassificationModule": ("models.lightning_module", "NewsClassificationModule"),
    "TransformerClassificationModule": ("models.transformer_lightning", "TransformerClassificationModule"),
}


def __getattr__(name: str) -> Any:
    """Resolve a lazily exported attribute of this package (PEP 562 hook).

    Python calls this only when *name* is not already in the module namespace.
    The backing module is imported on demand and the resolved object is cached
    in ``globals()``, so later lookups are plain namespace hits that no longer
    go through this hook.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The object named by the ``_LAZY`` entry for *name*.

    Raises:
        AttributeError: If *name* is not a lazily exported attribute, or the
            backing module does not define the expected attribute.
        ImportError: Propagated if the backing module (or one of its
            dependencies) cannot be imported.
    """
    entry = _LAZY.get(name)
    if entry is None:
        # Use the real module name so the message is correct however the
        # package is imported (identical text when it is imported as `models`).
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    module_name, attr_name = entry
    value = getattr(import_module(module_name), attr_name)
    # Cache so subsequent accesses skip the import machinery entirely.
    globals()[name] = value
    return value


def __dir__() -> list[str]:
    """List regular module attributes plus the lazily exported names.

    Without this hook ``dir()`` and tab completion only show names that have
    already been resolved. Listing does not import anything, but tools that
    ``getattr`` every ``dir()`` entry will trigger the lazy imports.
    """
    return sorted(set(globals()) | set(_LAZY))