import os
from typing import Any, Dict, Optional

from traditional_classifier import TraditionalClassifier
|
|
| try: |
| from modern_classifier import ModernClassifier |
| MODERN_MODELS_AVAILABLE = True |
| except ImportError: |
| MODERN_MODELS_AVAILABLE = False |
|
|
|
|
class ModelManager:
    """Manages different types of Arabic text classification models.

    Supports per-request model selection (falling back to a configurable
    default) and lazily caches loaded model instances so each model is
    loaded from disk at most once per manager instance.
    """

    # Registry of selectable models; artifact paths are relative to the
    # process working directory.
    AVAILABLE_MODELS = {
        "traditional_svm": {
            "type": "traditional",
            "classifier_path": "models/traditional_svm_classifier.joblib",
            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
            "description": "Traditional SVM classifier with TF-IDF vectorization"
        },

        "modern_bert": {
            "type": "modern",
            "model_type": "bert",
            "model_path": "models/modern_bert_classifier.safetensors",
            "config_path": "config.json",
            "description": "Modern BERT-based transformer classifier"
        },

        "modern_lstm": {
            "type": "modern",
            "model_type": "lstm",
            "model_path": "models/modern_lstm_classifier.pth",
            "description": "Modern LSTM-based neural network classifier"
        }
    }

    def __init__(self, default_model: str = "traditional_svm"):
        """Initialize the manager.

        Args:
            default_model: Model used when a request does not name one.
                Expected to be a key of ``AVAILABLE_MODELS``.
        """
        self.default_model = default_model
        # model name -> loaded classifier instance (populated lazily).
        self._model_cache: Dict[str, Any] = {}

    def _resolve_name(self, model_name: Optional[str]) -> str:
        """Return ``model_name``, or the manager's default when ``None``."""
        return self.default_model if model_name is None else model_name

    def _manager_metadata(self, model_name: str) -> Dict[str, str]:
        """Build the ``model_manager`` metadata stamped onto results."""
        return {
            "model_used": model_name,
            "model_description": self.AVAILABLE_MODELS[model_name]["description"]
        }

    def _get_model(self, model_name: str):
        """Return a loaded model instance, using the cache when possible.

        Raises:
            ValueError: If ``model_name`` is not registered, or its config
                carries an unknown ``type``.
            FileNotFoundError: If a required model artifact is missing.
            ImportError: If a modern model is requested but the optional
                PyTorch/transformers stack is not installed.
        """
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model '{model_name}' not available. Available models: {list(self.AVAILABLE_MODELS.keys())}")

        if model_name in self._model_cache:
            return self._model_cache[model_name]

        model_config = self.AVAILABLE_MODELS[model_name]

        if model_config["type"] == "traditional":
            classifier_path = model_config["classifier_path"]
            vectorizer_path = model_config["vectorizer_path"]

            if not os.path.exists(classifier_path):
                raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
            if not os.path.exists(vectorizer_path):
                raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}")

            model = TraditionalClassifier(classifier_path, vectorizer_path)

        elif model_config["type"] == "modern":
            if not MODERN_MODELS_AVAILABLE:
                raise ImportError("Modern models require PyTorch and transformers")

            model_path = model_config["model_path"]

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # A missing config file is tolerated: fall back to model defaults.
            config_path = model_config.get("config_path")
            if config_path and not os.path.exists(config_path):
                config_path = None

            model = ModernClassifier(
                model_type=model_config["model_type"],
                model_path=model_path,
                config_path=config_path
            )

        else:
            # Bug fix: an unrecognized "type" previously fell through both
            # branches and crashed with an opaque UnboundLocalError below.
            raise ValueError(f"Unknown model type: {model_config['type']}")

        self._model_cache[model_name] = model
        return model

    def predict(self, text: str, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Predict using the specified model (or default if none specified)."""
        model_name = self._resolve_name(model_name)
        result = self._get_model(model_name).predict(text)
        result["model_manager"] = self._manager_metadata(model_name)
        return result

    def predict_batch(self, texts: list, model_name: Optional[str] = None) -> list:
        """Predict batch using the specified model (or default if none specified)."""
        model_name = self._resolve_name(model_name)
        results = self._get_model(model_name).predict_batch(texts)
        metadata = self._manager_metadata(model_name)
        for result in results:
            # Fresh dict per result so callers mutating one entry's metadata
            # do not affect siblings (matches the original per-item dicts).
            result["model_manager"] = dict(metadata)
        return results

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Get information about a specific model (or default if none specified)."""
        model_name = self._resolve_name(model_name)
        model_info = self._get_model(model_name).get_model_info()
        model_info.update({
            "model_manager": {
                "model_name": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"],
                "model_config": self.AVAILABLE_MODELS[model_name],
                "is_cached": model_name in self._model_cache
            }
        })
        return model_info

    def get_available_models(self) -> Dict[str, Any]:
        """Get availability/status information for every registered model."""
        available = {}
        for model_name, config in self.AVAILABLE_MODELS.items():
            # Collect the on-disk artifacts this model needs.
            if config["type"] == "traditional":
                required = [config["classifier_path"], config["vectorizer_path"]]
            elif config["type"] == "modern":
                required = [config["model_path"]]
            else:
                required = []
            missing_files = [path for path in required if not os.path.exists(path)]

            available[model_name] = {
                "description": config["description"],
                "type": config["type"],
                "available": not missing_files,
                "missing_files": missing_files,
                "is_default": model_name == self.default_model,
                "is_cached": model_name in self._model_cache
            }

        return available

    def clear_cache(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Clear model cache (specific model or all models)."""
        if model_name:
            if model_name in self._model_cache:
                del self._model_cache[model_name]
                return {"message": f"Cache cleared for model: {model_name}"}
            return {"message": f"Model {model_name} was not cached"}
        cleared_count = len(self._model_cache)
        self._model_cache.clear()
        return {"message": f"Cache cleared for {cleared_count} models"}

    def get_cache_status(self) -> Dict[str, Any]:
        """Get information about cached models."""
        return {
            "cached_models": list(self._model_cache.keys()),
            "cache_count": len(self._model_cache),
            "default_model": self.default_model
        }
|
|