fahmiaziz98 committed on
Commit
1c7e30d
·
1 Parent(s): 65e4649

[UPDATE]: add configuration model

Browse files
src/config/settings.py CHANGED
@@ -28,35 +28,35 @@ class Settings(BaseSettings):
28
  HOST: str = "0.0.0.0"
29
  PORT: int = 7860
30
  WORKERS: int = 1
31
- RELOAD: bool = False # Auto-reload on code changes (dev only)
32
 
33
  # Model Configuration
34
  MODEL_CONFIG_PATH: str = "src/config/models.yaml"
35
  MODEL_CACHE_DIR: str = "./model_cache"
36
- PRELOAD_MODELS: bool = True # Load all models at startup
37
 
38
  # Request Limits
39
- MAX_TEXT_LENGTH: int = 32000 # Maximum characters per text
40
- MAX_BATCH_SIZE: int = 100 # Maximum texts per batch request
41
- REQUEST_TIMEOUT: int = 30 # Request timeout in seconds
42
 
43
  # Logging
44
  LOG_LEVEL: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
45
- LOG_FILE: bool = False # Write logs to file
46
  LOG_DIR: str = "logs"
47
 
48
- # CORS (if needed for web frontends)
49
  CORS_ENABLED: bool = False
50
  CORS_ORIGINS: list[str] = ["*"]
51
 
52
  # Model Settings
 
53
  TRUST_REMOTE_CODE: bool = True # For models requiring remote code
54
 
55
  model_config = SettingsConfigDict(
56
  env_file=".env",
57
  env_file_encoding="utf-8",
58
  case_sensitive=True,
59
- extra="ignore", # Ignore extra fields in .env
60
  )
61
 
62
  @property
@@ -86,10 +86,8 @@ class Settings(BaseSettings):
86
  f"Model configuration file not found: {self.MODEL_CONFIG_PATH}"
87
  )
88
 
89
- # Create cache directory if it doesn't exist
90
  Path(self.MODEL_CACHE_DIR).mkdir(parents=True, exist_ok=True)
91
 
92
- # Create log directory if logging to file
93
  if self.LOG_FILE:
94
  Path(self.LOG_DIR).mkdir(parents=True, exist_ok=True)
95
 
 
28
  HOST: str = "0.0.0.0"
29
  PORT: int = 7860
30
  WORKERS: int = 1
31
+ RELOAD: bool = False
32
 
33
  # Model Configuration
34
  MODEL_CONFIG_PATH: str = "src/config/models.yaml"
35
  MODEL_CACHE_DIR: str = "./model_cache"
36
+ PRELOAD_MODELS: bool = True
37
 
38
  # Request Limits
39
+ MAX_TEXT_LENGTH: int = 32000
40
+ MAX_BATCH_SIZE: int = 100
41
+ REQUEST_TIMEOUT: int = 30
42
 
43
  # Logging
44
  LOG_LEVEL: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
45
+ LOG_FILE: bool = True # Write logs to file
46
  LOG_DIR: str = "logs"
47
 
 
48
  CORS_ENABLED: bool = False
49
  CORS_ORIGINS: list[str] = ["*"]
50
 
51
  # Model Settings
52
+ DEVICE: str = "cpu" # "cpu" or "cuda"
53
  TRUST_REMOTE_CODE: bool = True # For models requiring remote code
54
 
55
  model_config = SettingsConfigDict(
56
  env_file=".env",
57
  env_file_encoding="utf-8",
58
  case_sensitive=True,
59
+ extra="ignore",
60
  )
61
 
62
  @property
 
86
  f"Model configuration file not found: {self.MODEL_CONFIG_PATH}"
87
  )
88
 
 
89
  Path(self.MODEL_CACHE_DIR).mkdir(parents=True, exist_ok=True)
90
 
 
91
  if self.LOG_FILE:
92
  Path(self.LOG_DIR).mkdir(parents=True, exist_ok=True)
93
 
src/models/embeddings/dense.py CHANGED
@@ -9,6 +9,7 @@ from typing import List, Optional
9
  from sentence_transformers import SentenceTransformer
10
  from loguru import logger
11
 
 
12
  from src.core.base import BaseEmbeddingModel
13
  from src.core.config import ModelConfig
14
  from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
@@ -36,6 +37,7 @@ class DenseEmbeddingModel(BaseEmbeddingModel):
36
  """
37
  super().__init__(config)
38
  self.model: Optional[SentenceTransformer] = None
 
39
 
40
  def load(self) -> None:
41
  """
@@ -52,7 +54,9 @@ class DenseEmbeddingModel(BaseEmbeddingModel):
52
 
53
  try:
54
  self.model = SentenceTransformer(
55
- self.config.name, device="cpu", trust_remote_code=True
 
 
56
  )
57
  self._loaded = True
58
  logger.success(f"✓ Loaded dense model: {self.model_id}")
 
9
  from sentence_transformers import SentenceTransformer
10
  from loguru import logger
11
 
12
+ from src.config.settings import get_settings
13
  from src.core.base import BaseEmbeddingModel
14
  from src.core.config import ModelConfig
15
  from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
 
37
  """
38
  super().__init__(config)
39
  self.model: Optional[SentenceTransformer] = None
40
+ self.settings = get_settings()
41
 
42
  def load(self) -> None:
43
  """
 
54
 
55
  try:
56
  self.model = SentenceTransformer(
57
+ self.config.name,
58
+ device=self.settings.DEVICE,
59
+ trust_remote_code=self.settings.TRUST_REMOTE_CODE
60
  )
61
  self._loaded = True
62
  logger.success(f"✓ Loaded dense model: {self.model_id}")
src/models/embeddings/sparse.py CHANGED
@@ -9,6 +9,7 @@ from typing import List, Optional, Dict, Any
9
  from sentence_transformers import SparseEncoder
10
  from loguru import logger
11
 
 
12
  from src.core.base import BaseEmbeddingModel
13
  from src.core.config import ModelConfig
14
  from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
@@ -36,6 +37,7 @@ class SparseEmbeddingModel(BaseEmbeddingModel):
36
  """
37
  super().__init__(config)
38
  self.model: Optional[SparseEncoder] = None
 
39
 
40
  def load(self) -> None:
41
  """
@@ -51,7 +53,11 @@ class SparseEmbeddingModel(BaseEmbeddingModel):
51
  logger.info(f"Loading sparse embedding model: {self.config.name}")
52
 
53
  try:
54
- self.model = SparseEncoder(self.config.name)
 
 
 
 
55
  self._loaded = True
56
  logger.success(f"✓ Loaded sparse model: {self.model_id}")
57
 
 
9
  from sentence_transformers import SparseEncoder
10
  from loguru import logger
11
 
12
+ from src.config.settings import get_settings
13
  from src.core.base import BaseEmbeddingModel
14
  from src.core.config import ModelConfig
15
  from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
 
37
  """
38
  super().__init__(config)
39
  self.model: Optional[SparseEncoder] = None
40
+ self.settings = get_settings()
41
 
42
  def load(self) -> None:
43
  """
 
53
  logger.info(f"Loading sparse embedding model: {self.config.name}")
54
 
55
  try:
56
+ self.model = SparseEncoder(
57
+ self.config.name,
58
+ device=self.settings.DEVICE,
59
+ trust_remote_code=self.settings.TRUST_REMOTE_CODE,
60
+ )
61
  self._loaded = True
62
  logger.success(f"✓ Loaded sparse model: {self.model_id}")
63