Spaces:
Running
Running
fahmiaziz98
committed on
Commit
·
1c7e30d
1
Parent(s):
65e4649
[UPDATE]: add configuration model
Browse files
- src/config/settings.py +8 -10
- src/models/embeddings/dense.py +5 -1
- src/models/embeddings/sparse.py +7 -1
src/config/settings.py
CHANGED
|
@@ -28,35 +28,35 @@ class Settings(BaseSettings):
|
|
| 28 |
HOST: str = "0.0.0.0"
|
| 29 |
PORT: int = 7860
|
| 30 |
WORKERS: int = 1
|
| 31 |
-
RELOAD: bool = False
|
| 32 |
|
| 33 |
# Model Configuration
|
| 34 |
MODEL_CONFIG_PATH: str = "src/config/models.yaml"
|
| 35 |
MODEL_CACHE_DIR: str = "./model_cache"
|
| 36 |
-
PRELOAD_MODELS: bool = True
|
| 37 |
|
| 38 |
# Request Limits
|
| 39 |
-
MAX_TEXT_LENGTH: int = 32000
|
| 40 |
-
MAX_BATCH_SIZE: int = 100
|
| 41 |
-
REQUEST_TIMEOUT: int = 30
|
| 42 |
|
| 43 |
# Logging
|
| 44 |
LOG_LEVEL: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
| 45 |
-
LOG_FILE: bool =
|
| 46 |
LOG_DIR: str = "logs"
|
| 47 |
|
| 48 |
-
# CORS (if needed for web frontends)
|
| 49 |
CORS_ENABLED: bool = False
|
| 50 |
CORS_ORIGINS: list[str] = ["*"]
|
| 51 |
|
| 52 |
# Model Settings
|
|
|
|
| 53 |
TRUST_REMOTE_CODE: bool = True # For models requiring remote code
|
| 54 |
|
| 55 |
model_config = SettingsConfigDict(
|
| 56 |
env_file=".env",
|
| 57 |
env_file_encoding="utf-8",
|
| 58 |
case_sensitive=True,
|
| 59 |
-
extra="ignore",
|
| 60 |
)
|
| 61 |
|
| 62 |
@property
|
|
@@ -86,10 +86,8 @@ class Settings(BaseSettings):
|
|
| 86 |
f"Model configuration file not found: {self.MODEL_CONFIG_PATH}"
|
| 87 |
)
|
| 88 |
|
| 89 |
-
# Create cache directory if it doesn't exist
|
| 90 |
Path(self.MODEL_CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
| 91 |
|
| 92 |
-
# Create log directory if logging to file
|
| 93 |
if self.LOG_FILE:
|
| 94 |
Path(self.LOG_DIR).mkdir(parents=True, exist_ok=True)
|
| 95 |
|
|
|
|
| 28 |
HOST: str = "0.0.0.0"
|
| 29 |
PORT: int = 7860
|
| 30 |
WORKERS: int = 1
|
| 31 |
+
RELOAD: bool = False
|
| 32 |
|
| 33 |
# Model Configuration
|
| 34 |
MODEL_CONFIG_PATH: str = "src/config/models.yaml"
|
| 35 |
MODEL_CACHE_DIR: str = "./model_cache"
|
| 36 |
+
PRELOAD_MODELS: bool = True
|
| 37 |
|
| 38 |
# Request Limits
|
| 39 |
+
MAX_TEXT_LENGTH: int = 32000
|
| 40 |
+
MAX_BATCH_SIZE: int = 100
|
| 41 |
+
REQUEST_TIMEOUT: int = 30
|
| 42 |
|
| 43 |
# Logging
|
| 44 |
LOG_LEVEL: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
| 45 |
+
LOG_FILE: bool = True # Write logs to file
|
| 46 |
LOG_DIR: str = "logs"
|
| 47 |
|
|
|
|
| 48 |
CORS_ENABLED: bool = False
|
| 49 |
CORS_ORIGINS: list[str] = ["*"]
|
| 50 |
|
| 51 |
# Model Settings
|
| 52 |
+
DEVICE: str = "cpu" # "cpu" or "cuda"
|
| 53 |
TRUST_REMOTE_CODE: bool = True # For models requiring remote code
|
| 54 |
|
| 55 |
model_config = SettingsConfigDict(
|
| 56 |
env_file=".env",
|
| 57 |
env_file_encoding="utf-8",
|
| 58 |
case_sensitive=True,
|
| 59 |
+
extra="ignore",
|
| 60 |
)
|
| 61 |
|
| 62 |
@property
|
|
|
|
| 86 |
f"Model configuration file not found: {self.MODEL_CONFIG_PATH}"
|
| 87 |
)
|
| 88 |
|
|
|
|
| 89 |
Path(self.MODEL_CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
| 90 |
|
|
|
|
| 91 |
if self.LOG_FILE:
|
| 92 |
Path(self.LOG_DIR).mkdir(parents=True, exist_ok=True)
|
| 93 |
|
src/models/embeddings/dense.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import List, Optional
|
|
| 9 |
from sentence_transformers import SentenceTransformer
|
| 10 |
from loguru import logger
|
| 11 |
|
|
|
|
| 12 |
from src.core.base import BaseEmbeddingModel
|
| 13 |
from src.core.config import ModelConfig
|
| 14 |
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
|
|
@@ -36,6 +37,7 @@ class DenseEmbeddingModel(BaseEmbeddingModel):
|
|
| 36 |
"""
|
| 37 |
super().__init__(config)
|
| 38 |
self.model: Optional[SentenceTransformer] = None
|
|
|
|
| 39 |
|
| 40 |
def load(self) -> None:
|
| 41 |
"""
|
|
@@ -52,7 +54,9 @@ class DenseEmbeddingModel(BaseEmbeddingModel):
|
|
| 52 |
|
| 53 |
try:
|
| 54 |
self.model = SentenceTransformer(
|
| 55 |
-
self.config.name,
|
|
|
|
|
|
|
| 56 |
)
|
| 57 |
self._loaded = True
|
| 58 |
logger.success(f"✓ Loaded dense model: {self.model_id}")
|
|
|
|
| 9 |
from sentence_transformers import SentenceTransformer
|
| 10 |
from loguru import logger
|
| 11 |
|
| 12 |
+
from src.config.settings import get_settings
|
| 13 |
from src.core.base import BaseEmbeddingModel
|
| 14 |
from src.core.config import ModelConfig
|
| 15 |
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
|
|
|
|
| 37 |
"""
|
| 38 |
super().__init__(config)
|
| 39 |
self.model: Optional[SentenceTransformer] = None
|
| 40 |
+
self.settings = get_settings()
|
| 41 |
|
| 42 |
def load(self) -> None:
|
| 43 |
"""
|
|
|
|
| 54 |
|
| 55 |
try:
|
| 56 |
self.model = SentenceTransformer(
|
| 57 |
+
self.config.name,
|
| 58 |
+
device=self.settings.DEVICE,
|
| 59 |
+
trust_remote_code=self.settings.TRUST_REMOTE_CODE
|
| 60 |
)
|
| 61 |
self._loaded = True
|
| 62 |
logger.success(f"✓ Loaded dense model: {self.model_id}")
|
src/models/embeddings/sparse.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import List, Optional, Dict, Any
|
|
| 9 |
from sentence_transformers import SparseEncoder
|
| 10 |
from loguru import logger
|
| 11 |
|
|
|
|
| 12 |
from src.core.base import BaseEmbeddingModel
|
| 13 |
from src.core.config import ModelConfig
|
| 14 |
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
|
|
@@ -36,6 +37,7 @@ class SparseEmbeddingModel(BaseEmbeddingModel):
|
|
| 36 |
"""
|
| 37 |
super().__init__(config)
|
| 38 |
self.model: Optional[SparseEncoder] = None
|
|
|
|
| 39 |
|
| 40 |
def load(self) -> None:
|
| 41 |
"""
|
|
@@ -51,7 +53,11 @@ class SparseEmbeddingModel(BaseEmbeddingModel):
|
|
| 51 |
logger.info(f"Loading sparse embedding model: {self.config.name}")
|
| 52 |
|
| 53 |
try:
|
| 54 |
-
self.model = SparseEncoder(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
self._loaded = True
|
| 56 |
logger.success(f"✓ Loaded sparse model: {self.model_id}")
|
| 57 |
|
|
|
|
| 9 |
from sentence_transformers import SparseEncoder
|
| 10 |
from loguru import logger
|
| 11 |
|
| 12 |
+
from src.config.settings import get_settings
|
| 13 |
from src.core.base import BaseEmbeddingModel
|
| 14 |
from src.core.config import ModelConfig
|
| 15 |
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
|
|
|
|
| 37 |
"""
|
| 38 |
super().__init__(config)
|
| 39 |
self.model: Optional[SparseEncoder] = None
|
| 40 |
+
self.settings = get_settings()
|
| 41 |
|
| 42 |
def load(self) -> None:
|
| 43 |
"""
|
|
|
|
| 53 |
logger.info(f"Loading sparse embedding model: {self.config.name}")
|
| 54 |
|
| 55 |
try:
|
| 56 |
+
self.model = SparseEncoder(
|
| 57 |
+
self.config.name,
|
| 58 |
+
device=self.settings.DEVICE,
|
| 59 |
+
trust_remote_code=self.settings.TRUST_REMOTE_CODE,
|
| 60 |
+
)
|
| 61 |
self._loaded = True
|
| 62 |
logger.success(f"✓ Loaded sparse model: {self.model_id}")
|
| 63 |
|