Spaces:
Runtime error
Runtime error
| from typing import Dict, Any | |
class ModelRegistry:
    """
    In-memory registry mapping a unique name to a model object, an optional
    tokenizer, and free-form metadata.

    Lookups by an unknown name raise ``ValueError``; registration and removal
    print a short ``[REGISTRY]`` status line to stdout.
    """

    def __init__(self):
        """
        Create an empty registry.

        Stores registered models. Structure:
        {
            "model_name": {
                "model": model_object,
                "tokenizer": tokenizer_object,
                "metadata": {}
            }
        }
        """
        self.models: Dict[str, Dict[str, Any]] = {}

    # ---------------------------------------------------
    # Internal lookup shared by the getters
    # ---------------------------------------------------
    def _entry(self, name: str, error_message: str) -> Dict[str, Any]:
        """
        Return the stored record for ``name``.

        Args:
            name (str): registered model name
            error_message (str): message used if ``name`` is not registered

        Raises:
            ValueError: if ``name`` is not registered.
        """
        # EAFP: a single dict lookup instead of `in` followed by indexing.
        try:
            return self.models[name]
        except KeyError:
            # Keep the public ValueError contract; hide the internal KeyError.
            raise ValueError(error_message) from None

    # ---------------------------------------------------
    # Register a new model
    # ---------------------------------------------------
    def register(self, name: str, model: Any, tokenizer: Any = None, **metadata: Any) -> None:
        """
        Register a model in the system, replacing any existing entry with the
        same name.

        Args:
            name (str): unique name for the model
            model: the model object
            tokenizer: optional tokenizer (stored as None when omitted)
            metadata: any extra info (type, language, task)
        """
        if name in self.models:
            print(f"[REGISTRY] Updating existing model: {name}")
        # A re-registration overwrites model, tokenizer AND metadata wholesale.
        self.models[name] = {
            "model": model,
            "tokenizer": tokenizer,
            "metadata": metadata,
        }
        print(f"[REGISTRY] Model registered → {name}")

    # ---------------------------------------------------
    # Get a model
    # ---------------------------------------------------
    def get(self, name: str) -> Any:
        """
        Retrieve a model by name.

        Raises:
            ValueError: if ``name`` is not registered.
        """
        return self._entry(name, f"Model '{name}' not found in registry")["model"]

    # ---------------------------------------------------
    # Get tokenizer
    # ---------------------------------------------------
    def get_tokenizer(self, name: str) -> Any:
        """
        Retrieve the tokenizer for a model (None if it was registered without one).

        Raises:
            ValueError: if ``name`` is not registered.
        """
        return self._entry(name, f"Tokenizer for '{name}' not found")["tokenizer"]

    # ---------------------------------------------------
    # Get metadata
    # ---------------------------------------------------
    def metadata(self, name: str) -> Dict[str, Any]:
        """
        Retrieve model metadata (the keyword arguments passed to ``register``).

        The stored dict itself is returned, so mutating it changes the registry.

        Raises:
            ValueError: if ``name`` is not registered.
        """
        return self._entry(name, f"Metadata for '{name}' not found")["metadata"]

    # ---------------------------------------------------
    # List all models
    # ---------------------------------------------------
    def list_models(self) -> List[str]:
        """
        Return the registered model names in registration order.
        """
        return list(self.models)

    # ---------------------------------------------------
    # Remove model
    # ---------------------------------------------------
    def remove(self, name: str) -> bool:
        """
        Remove a model from the registry.

        Unknown names are ignored (no exception is raised).

        Returns:
            bool: True if an entry was removed, False if ``name`` was not registered.
        """
        if name not in self.models:
            return False
        del self.models[name]
        print(f"[REGISTRY] Removed model → {name}")
        return True

    # ---------------------------------------------------
    # Check if model exists
    # ---------------------------------------------------
    def exists(self, name: str) -> bool:
        """
        Return True if a model is registered under ``name``.
        """
        return name in self.models

    def __contains__(self, name: object) -> bool:
        """
        Support ``name in registry``; equivalent to ``exists(name)``.
        """
        return name in self.models