Spaces:
Sleeping
Sleeping
File size: 5,931 Bytes
"""
Model factory for creating LLM and embedding models.
Handles model switching and fallback logic.
"""
from typing import Optional
from pathlib import Path
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.chat_models import ChatLlamaCpp
from app.core.config import settings
import logging
# Module-level logger used by every factory function in this file.
logger = logging.getLogger(__name__)
def get_embedding_model():
    """Build the embedding model used by the application.

    Only the Google Gemini embedding backend is supported at the moment. The
    model name and API key are read from ``settings``.

    Returns:
        GoogleGenerativeAIEmbeddings: A configured embedding model instance.

    Raises:
        Exception: Any error raised while building the model is logged and
            then re-raised unchanged.
    """
    try:
        model_name = settings.embedding_model_name
        embedding_model = GoogleGenerativeAIEmbeddings(
            model=model_name,
            google_api_key=settings.google_api_key,
        )
        logger.info(f"Loaded embedding model: {model_name}")
        return embedding_model
    except Exception as exc:
        # Log for operators, but let the caller decide how to recover.
        logger.error(f"Failed to load embedding model: {exc}")
        raise
def get_gemini_model():
    """Build the Google Gemini chat model.

    The model name and API key are read from ``settings``.

    Returns:
        ChatGoogleGenerativeAI: A configured Gemini chat model instance.

    Raises:
        Exception: Any error raised while building the model is logged and
            then re-raised unchanged.
    """
    try:
        model_name = settings.gemini_model_name
        gemini_chat = ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=settings.google_api_key,
        )
        logger.info(f"Loaded Gemini model: {model_name}")
        return gemini_chat
    except Exception as exc:
        # Log for operators, but let the caller decide how to recover.
        logger.error(f"Failed to load Gemini model: {exc}")
        raise
def get_local_model():
    """Build the local Qwen chat model served through llama.cpp.

    The model file is looked up at
    ``settings.model_path / settings.local_model_name``.

    Returns:
        ChatLlamaCpp: A configured local chat model instance.

    Raises:
        FileNotFoundError: If the model file does not exist on disk.
        Exception: Any other error raised while building the model is logged
            and then re-raised unchanged.
    """
    try:
        model_file = settings.model_path / settings.local_model_name
        if not model_file.exists():
            raise FileNotFoundError(
                f"Model file not found: {model_file}\n"
                f"Please download it to {settings.model_path}/"
            )

        llama_options = dict(
            model_path=str(model_file),
            # Context window size in tokens.
            # NOTE(review): 8096 is not a power of two; 8192 may have been
            # intended -- confirm before changing.
            n_ctx=8096,
            n_batch=512,  # Batch size for prompt processing
            n_threads=4,  # Number of CPU threads
            max_tokens=settings.local_max_tokens,  # Maximum tokens to generate
            temperature=0.1,  # Low temperature: focused, less random output
            top_p=0.9,  # Nucleus sampling
            top_k=30,  # Top-k sampling
            repeat_penalty=1.05,  # Penalty for repetition
            f16_kv=True,  # Half-precision KV cache
            # NOTE(review): `f16`, `chat_format` and `numa` are passed as
            # direct keyword arguments; verify that ChatLlamaCpp accepts them
            # here and does not expect them inside `model_kwargs` -- confirm.
            f16=True,
            verbose=True,
            chat_format="chatml",
            numa=False,
            use_mlock=False,  # Do not lock the model into memory
            use_mmap=True,  # Memory-map the model file
        )
        local_chat = ChatLlamaCpp(**llama_options)

        logger.info(f"Loaded local model: {settings.local_model_name}")
        return local_chat
    except Exception as exc:
        # Log for operators, but let the caller decide how to recover.
        logger.error(f"Failed to load local model: {exc}")
        raise
def get_llm_model(provider: Optional[str] = None):
    """Return the chat model for the requested provider, with optional fallback.

    The requested provider is tried first. If it fails to load and
    ``settings.enable_fallback`` is truthy, the other provider is tried.

    Args:
        provider: Either "gemini" or "local". When falsy (``None`` or an
            empty string), ``settings.llm_provider`` is used instead.

    Returns:
        The Gemini or local chat model instance.

    Raises:
        ValueError: If the resolved provider is neither "gemini" nor "local".
        Exception: The primary loader's error when fallback is disabled, or
            the fallback loader's error when both providers fail.
    """
    provider = provider or settings.llm_provider

    # Resolve loaders per provider; the unknown-provider path must fail
    # before any model loading is attempted.
    if provider == "gemini":
        load_primary, load_fallback = get_gemini_model, get_local_model
        primary_label, fallback_label = "Gemini", "local"
    elif provider == "local":
        load_primary, load_fallback = get_local_model, get_gemini_model
        primary_label, fallback_label = "Local", "Gemini"
    else:
        raise ValueError(f"Unknown provider: {provider}. Use 'gemini' or 'local'")

    # Report the selection through the logger; success is logged by the
    # individual loader once the model has actually been built.
    logger.info(f"Selected LLM provider: {provider}")

    try:
        return load_primary()
    except Exception as exc:
        logger.warning(f"{primary_label} model failed: {exc}")
        if not settings.enable_fallback:
            raise
        logger.info(f"Falling back to {fallback_label} model...")
        # If the fallback also fails, its exception propagates with the
        # primary failure attached as context.
        return load_fallback()
|