Spaces:
Sleeping
Sleeping
File size: 6,996 Bytes
f871fed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
from typing import ClassVar, Dict, Optional, Union
from esperanto import (
AIFactory,
EmbeddingModel,
LanguageModel,
SpeechToTextModel,
TextToSpeechModel,
)
from loguru import logger
from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.base import ObjectModel, RecordModel
# Union of the Esperanto model interfaces that ModelManager.get_model can return.
ModelType = Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]
class Model(ObjectModel):
    """A configured AI model record stored in the ``model`` table."""

    table_name: ClassVar[str] = "model"
    name: str
    provider: str
    type: str

    @classmethod
    async def get_models_by_type(cls, model_type):
        """Return every stored ``Model`` whose ``type`` equals ``model_type``.

        Args:
            model_type: Value matched against the ``type`` column; it is bound
                as the ``$model_type`` query parameter rather than being
                interpolated into the query text.

        Returns:
            A list of ``Model`` instances, one per row returned by the query.
        """
        query = "SELECT * FROM model WHERE type=$model_type;"
        params = {"model_type": model_type}
        rows = await repo_query(query, params)
        return [Model(**row) for row in rows]
class DefaultModels(RecordModel):
    """Record holding the IDs of the models used by default for each task."""

    record_id: ClassVar[str] = "open_notebook:default_models"
    default_chat_model: Optional[str] = None
    default_transformation_model: Optional[str] = None
    large_context_model: Optional[str] = None
    default_text_to_speech_model: Optional[str] = None
    default_speech_to_text_model: Optional[str] = None
    # default_vision_model: Optional[str]
    default_embedding_model: Optional[str] = None
    default_tools_model: Optional[str] = None

    @classmethod
    async def get_instance(cls) -> "DefaultModels":
        """Load the defaults record from the database on every call.

        This intentionally replaces the parent's ``get_instance`` so callers
        always see the stored values instead of a previously cached instance.

        Returns:
            A new ``DefaultModels`` populated from the stored record, or one
            with all fields left at their defaults when the query yields
            nothing usable.
        """
        rows = await repo_query(
            "SELECT * FROM ONLY $record_id",
            {"record_id": ensure_record_id(cls.record_id)},
        )

        # The query result may be a list of rows or a single dict; anything
        # else (including empty results) is treated as "no stored data".
        if isinstance(rows, list) and rows:
            payload = rows[0]
        elif isinstance(rows, dict) and rows:
            payload = rows
        else:
            payload = {}

        # Build the object by hand so no cached singleton is reused:
        # object.__new__ skips any custom __new__ on the class, and the
        # super(RecordModel, ...) call runs the initializer that follows
        # RecordModel in the MRO, skipping RecordModel.__init__ itself.
        # The order of these three statements matters.
        fresh = object.__new__(cls)
        object.__setattr__(fresh, "__dict__", {})
        super(RecordModel, fresh).__init__(**payload)
        return fresh
class ModelManager:
    """Turn stored model records and configured defaults into Esperanto models.

    The manager keeps no state of its own; instantiation is delegated to
    ``AIFactory`` (which, per the original notes, caches model instances).
    """

    def __init__(self):
        """Create a manager. Nothing to initialise: no caching is needed here."""

    async def get_model(self, model_id: str, **kwargs) -> Optional[ModelType]:
        """Get a model by ID. Esperanto will cache the actual model instance.

        Args:
            model_id: ID of the stored ``Model`` record to load.
            **kwargs: Passed to the Esperanto factory as its ``config`` dict.

        Returns:
            The Esperanto model built for the record, or ``None`` when
            ``model_id`` is empty.

        Raises:
            ValueError: If the record cannot be loaded, or if its ``type`` is
                not one of ``language``, ``embedding``, ``speech_to_text`` or
                ``text_to_speech``.
        """
        if not model_id:
            return None

        try:
            model: Model = await Model.get(model_id)
        except Exception as exc:
            # Chain the underlying error so that failures other than a missing
            # record remain visible behind the "not found" message.
            raise ValueError(f"Model with ID {model_id} not found") from exc

        # Single source of truth for the supported types. Built per call so
        # the factory attributes are resolved at call time.
        factories = {
            "language": AIFactory.create_language,
            "embedding": AIFactory.create_embedding,
            "speech_to_text": AIFactory.create_speech_to_text,
            "text_to_speech": AIFactory.create_text_to_speech,
        }
        factory = factories.get(model.type)
        if factory is None:
            # Covers both a missing/empty type and an unsupported one.
            raise ValueError(f"Invalid model type: {model.type}")

        return factory(
            model_name=model.name,
            provider=model.provider,
            config=kwargs,
        )

    async def get_defaults(self) -> DefaultModels:
        """Get the default models configuration from database.

        Raises:
            RuntimeError: If no configuration object could be loaded.
        """
        defaults = await DefaultModels.get_instance()
        if not defaults:
            raise RuntimeError("Failed to load default models configuration")
        return defaults

    async def get_speech_to_text(self, **kwargs) -> Optional[SpeechToTextModel]:
        """Get the default speech-to-text model.

        Returns ``None`` when no default is configured.

        Raises:
            AssertionError: If the configured default is not a
                ``SpeechToTextModel``.
        """
        defaults = await self.get_defaults()
        model_id = defaults.default_speech_to_text_model
        if not model_id:
            return None
        model = await self.get_model(model_id, **kwargs)
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if model is not None and not isinstance(model, SpeechToTextModel):
            raise AssertionError(
                f"Expected SpeechToTextModel but got {type(model)}"
            )
        return model

    async def get_text_to_speech(self, **kwargs) -> Optional[TextToSpeechModel]:
        """Get the default text-to-speech model.

        Returns ``None`` when no default is configured.

        Raises:
            AssertionError: If the configured default is not a
                ``TextToSpeechModel``.
        """
        defaults = await self.get_defaults()
        model_id = defaults.default_text_to_speech_model
        if not model_id:
            return None
        model = await self.get_model(model_id, **kwargs)
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if model is not None and not isinstance(model, TextToSpeechModel):
            raise AssertionError(
                f"Expected TextToSpeechModel but got {type(model)}"
            )
        return model

    async def get_embedding_model(self, **kwargs) -> Optional[EmbeddingModel]:
        """Get the default embedding model.

        Returns ``None`` when no default is configured.

        Raises:
            AssertionError: If the configured default is not an
                ``EmbeddingModel``.
        """
        defaults = await self.get_defaults()
        model_id = defaults.default_embedding_model
        if not model_id:
            return None
        model = await self.get_model(model_id, **kwargs)
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if model is not None and not isinstance(model, EmbeddingModel):
            raise AssertionError(
                f"Expected EmbeddingModel but got {type(model)}"
            )
        return model

    async def get_default_model(self, model_type: str, **kwargs) -> Optional[ModelType]:
        """Get the default model for a specific type.

        Args:
            model_type: One of ``chat``, ``transformation``, ``tools``,
                ``embedding``, ``text_to_speech``, ``speech_to_text`` or
                ``large_context``. ``transformation`` and ``tools`` fall back
                to the default chat model when they have no value of their own.
            **kwargs: Additional arguments to pass to the model constructor.

        Returns:
            The resolved model, or ``None`` when ``model_type`` is not
            recognised or no model is configured for it.
        """
        defaults = await self.get_defaults()

        # Configured model ID for each supported type, fallbacks included.
        configured = {
            "chat": defaults.default_chat_model,
            "transformation": (
                defaults.default_transformation_model
                or defaults.default_chat_model
            ),
            "tools": defaults.default_tools_model or defaults.default_chat_model,
            "embedding": defaults.default_embedding_model,
            "text_to_speech": defaults.default_text_to_speech_model,
            "speech_to_text": defaults.default_speech_to_text_model,
            "large_context": defaults.large_context_model,
        }
        model_id = configured.get(model_type)
        if not model_id:
            return None

        return await self.get_model(model_id, **kwargs)
# Module-level ModelManager instance, created when this module is imported.
model_manager = ModelManager()
|