| import os | |
| from typing import Any, Dict, List, Optional | |
| from traditional_classifier import TraditionalClassifier | |
| try: | |
| from modern_classifier import ModernClassifier | |
| MODERN_MODELS_AVAILABLE = True | |
| except ImportError: | |
| MODERN_MODELS_AVAILABLE = False | |
| class ModelManager: | |
| """Manages different types of Arabic text classification models with per-request model selection and caching.""" | |
| AVAILABLE_MODELS = { | |
| "traditional_svm": { | |
| "type": "traditional", | |
| "classifier_path": "models/traditional_svm_classifier.joblib", | |
| "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib", | |
| "description": "Traditional SVM classifier with TF-IDF vectorization" | |
| }, | |
| "modern_bert": { | |
| "type": "modern", | |
| "model_type": "bert", | |
| "model_path": "models/modern_bert_classifier.safetensors", | |
| "config_path": "config.json", | |
| "description": "Modern BERT-based transformer classifier" | |
| }, | |
| "modern_lstm": { | |
| "type": "modern", | |
| "model_type": "lstm", | |
| "model_path": "models/modern_lstm_classifier.pth", | |
| "description": "Modern LSTM-based neural network classifier" | |
| } | |
| } | |
| def __init__(self, default_model: str = "traditional_svm"): | |
| self.default_model = default_model | |
| self._model_cache = {} | |
| def _get_model(self, model_name: str): | |
| """Get model instance, loading from cache or creating new one.""" | |
| if model_name not in self.AVAILABLE_MODELS: | |
| raise ValueError(f"Model '{model_name}' not available. Available models: {list(self.AVAILABLE_MODELS.keys())}") | |
| if model_name in self._model_cache: | |
| return self._model_cache[model_name] | |
| model_config = self.AVAILABLE_MODELS[model_name] | |
| if model_config["type"] == "traditional": | |
| classifier_path = model_config["classifier_path"] | |
| vectorizer_path = model_config["vectorizer_path"] | |
| if not os.path.exists(classifier_path): | |
| raise FileNotFoundError(f"Classifier file not found: {classifier_path}") | |
| if not os.path.exists(vectorizer_path): | |
| raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}") | |
| model = TraditionalClassifier(classifier_path, vectorizer_path) | |
| elif model_config["type"] == "modern": | |
| if not MODERN_MODELS_AVAILABLE: | |
| raise ImportError("Modern models require PyTorch and transformers") | |
| model_path = model_config["model_path"] | |
| if not os.path.exists(model_path): | |
| raise FileNotFoundError(f"Model file not found: {model_path}") | |
| config_path = model_config.get("config_path") | |
| if config_path and not os.path.exists(config_path): | |
| config_path = None | |
| model = ModernClassifier( | |
| model_type=model_config["model_type"], | |
| model_path=model_path, | |
| config_path=config_path | |
| ) | |
| self._model_cache[model_name] = model | |
| return model | |
| def predict(self, text: str, model_name: str = None) -> Dict[str, Any]: | |
| """Predict using the specified model (or default if none specified).""" | |
| if model_name is None: | |
| model_name = self.default_model | |
| model = self._get_model(model_name) | |
| result = model.predict(text) | |
| result["model_manager"] = { | |
| "model_used": model_name, | |
| "model_description": self.AVAILABLE_MODELS[model_name]["description"] | |
| } | |
| return result | |
| def predict_batch(self, texts: list, model_name: str = None) -> list: | |
| """Predict batch using the specified model (or default if none specified).""" | |
| if model_name is None: | |
| model_name = self.default_model | |
| model = self._get_model(model_name) | |
| results = model.predict_batch(texts) | |
| for result in results: | |
| result["model_manager"] = { | |
| "model_used": model_name, | |
| "model_description": self.AVAILABLE_MODELS[model_name]["description"] | |
| } | |
| return results | |
| def get_model_info(self, model_name: str = None) -> Dict[str, Any]: | |
| """Get information about a specific model (or default if none specified).""" | |
| if model_name is None: | |
| model_name = self.default_model | |
| model = self._get_model(model_name) | |
| model_info = model.get_model_info() | |
| model_info.update({ | |
| "model_manager": { | |
| "model_name": model_name, | |
| "model_description": self.AVAILABLE_MODELS[model_name]["description"], | |
| "model_config": self.AVAILABLE_MODELS[model_name], | |
| "is_cached": model_name in self._model_cache | |
| } | |
| }) | |
| return model_info | |
| def get_available_models(self) -> Dict[str, Any]: | |
| """Get list of all available models.""" | |
| available = {} | |
| for model_name, config in self.AVAILABLE_MODELS.items(): | |
| files_exist = True | |
| missing_files = [] | |
| if config["type"] == "traditional": | |
| for file_key in ["classifier_path", "vectorizer_path"]: | |
| if not os.path.exists(config[file_key]): | |
| files_exist = False | |
| missing_files.append(config[file_key]) | |
| elif config["type"] == "modern": | |
| if not os.path.exists(config["model_path"]): | |
| files_exist = False | |
| missing_files.append(config["model_path"]) | |
| available[model_name] = { | |
| "description": config["description"], | |
| "type": config["type"], | |
| "available": files_exist, | |
| "missing_files": missing_files if not files_exist else [], | |
| "is_default": model_name == self.default_model, | |
| "is_cached": model_name in self._model_cache | |
| } | |
| return available | |
| def clear_cache(self, model_name: str = None) -> Dict[str, Any]: | |
| """Clear model cache (specific model or all models).""" | |
| if model_name: | |
| if model_name in self._model_cache: | |
| del self._model_cache[model_name] | |
| return {"message": f"Cache cleared for model: {model_name}"} | |
| else: | |
| return {"message": f"Model {model_name} was not cached"} | |
| else: | |
| cleared_count = len(self._model_cache) | |
| self._model_cache.clear() | |
| return {"message": f"Cache cleared for {cleared_count} models"} | |
| def get_cache_status(self) -> Dict[str, Any]: | |
| """Get information about cached models.""" | |
| return { | |
| "cached_models": list(self._model_cache.keys()), | |
| "cache_count": len(self._model_cache), | |
| "default_model": self.default_model | |
| } | |