File size: 2,291 Bytes
"""Model architectures for news classification.

Important: keep this package lightweight at import time.
Render/Uvicorn must import `api.main:app` before binding the port; importing the
training stack (pytorch-lightning/torchmetrics/matplotlib) here can delay startup
and cause Render port-scan timeouts.
"""
from importlib import import_module
from typing import Any
# Public names of the package. Listing a name here does not import anything:
# each one is looked up on demand when it is first requested from the package.
__all__ = [
    # Classifier architectures
    "SimpleClassifier",
    "CNNClassifier",
    "RussianNewsClassifier",
    "MultilingualBERTClassifier",
    "RoBERTaNewsClassifier",
    "DistilBERTNewsClassifier",
    "MultiHeadAttentionClassifier",
    "EnsembleClassifier",
    # Ensembles
    "WeightedEnsemble",
    "StackingEnsemble",
    "VotingEnsemble",
    "create_ensemble",
    # Training-only wrappers (exported for convenience, still lazily imported)
    "NewsClassificationModule",
    "TransformerClassificationModule",
]
# Public name -> (module to import, attribute to read from that module).
# Nothing named here is imported until the name is first requested from this
# package, which keeps importing the package itself cheap.
_LAZY: dict[str, tuple[str, str]] = {
    "SimpleClassifier": ("models.simple_classifier", "SimpleClassifier"),
    "CNNClassifier": ("models.cnn_classifier", "CNNClassifier"),
    "RussianNewsClassifier": ("models.transformer_model", "RussianNewsClassifier"),
    "MultilingualBERTClassifier": ("models.transformer_model", "MultilingualBERTClassifier"),
    "RoBERTaNewsClassifier": ("models.advanced_transformers", "RoBERTaNewsClassifier"),
    "DistilBERTNewsClassifier": ("models.advanced_transformers", "DistilBERTNewsClassifier"),
    "MultiHeadAttentionClassifier": (
        "models.advanced_transformers",
        "MultiHeadAttentionClassifier",
    ),
    "EnsembleClassifier": ("models.advanced_transformers", "EnsembleClassifier"),
    "WeightedEnsemble": ("models.ensemble", "WeightedEnsemble"),
    "StackingEnsemble": ("models.ensemble", "StackingEnsemble"),
    "VotingEnsemble": ("models.ensemble", "VotingEnsemble"),
    "create_ensemble": ("models.ensemble", "create_ensemble"),
    # training-only
    "NewsClassificationModule": ("models.lightning_module", "NewsClassificationModule"),
    "TransformerClassificationModule": (
        "models.transformer_lightning",
        "TransformerClassificationModule",
    ),
}


def __getattr__(name: str) -> Any:
    """Resolve a lazily exported name the first time it is requested (PEP 562).

    Python calls this hook only when normal lookup in the module namespace
    fails. The target module recorded in ``_LAZY`` is imported, the attribute
    is read from it, and the result is stored in this module's globals so that
    later accesses are plain attribute lookups that bypass this function.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The object exported under ``name`` by its implementing module.

    Raises:
        AttributeError: If ``name`` is not a key of ``_LAZY``, or the target
            module does not define the recorded attribute.
        ImportError: Propagated unchanged if the target module (or one of its
            own dependencies) cannot be imported.
    """
    try:
        module_name, attr_name = _LAZY[name]
    except KeyError:
        # Mirror the interpreter's own wording, using the module's real name
        # rather than a hard-coded one; drop the internal KeyError context.
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None

    value = getattr(import_module(module_name), attr_name)
    # Cache only after a successful resolution, so a failed import is retried
    # on the next access instead of being remembered.
    globals()[name] = value
    return value


def __dir__() -> list[str]:
    """Return the module's attribute names, including unresolved lazy exports.

    Without this hook ``dir()`` would only list names already present in the
    module namespace, hiding every lazy export that has not been accessed yet.
    """
    return sorted({*globals(), *_LAZY})
|