vgecbot / app /utils /model_factory.py
harsh-dev's picture
docker deployment
4225666
"""
Model factory for creating LLM and embedding models.
Handles model switching and fallback logic.
"""
from typing import Optional
from pathlib import Path
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.chat_models import ChatLlamaCpp
from app.core.config import settings
import logging
logger = logging.getLogger(__name__)
def get_embedding_model():
    """
    Build the embedding model used by the application.

    Only the Google Generative AI (Gemini) embedding backend is supported;
    the model name and API key are read from ``settings``.

    Returns:
        GoogleGenerativeAIEmbeddings: Embedding model instance

    Raises:
        Exception: Any error raised while constructing the model is logged
            and re-raised unchanged.
    """
    try:
        init_kwargs = {
            "model": settings.embedding_model_name,
            "google_api_key": settings.google_api_key,
        }
        embedder = GoogleGenerativeAIEmbeddings(**init_kwargs)
        logger.info(f"Loaded embedding model: {settings.embedding_model_name}")
        return embedder
    except Exception as exc:
        # Log for visibility, then let the caller decide how to react.
        logger.error(f"Failed to load embedding model: {exc}")
        raise
def get_gemini_model():
    """
    Build the Google Gemini chat model.

    The model name and API key are read from ``settings``.

    Returns:
        ChatGoogleGenerativeAI: Gemini model instance

    Raises:
        Exception: Any error raised while constructing the model is logged
            and re-raised unchanged.
    """
    try:
        init_kwargs = {
            "model": settings.gemini_model_name,
            "google_api_key": settings.google_api_key,
        }
        gemini = ChatGoogleGenerativeAI(**init_kwargs)
        logger.info(f"Loaded Gemini model: {settings.gemini_model_name}")
        return gemini
    except Exception as exc:
        # Log for visibility, then let the caller decide how to react.
        logger.error(f"Failed to load Gemini model: {exc}")
        raise
def get_local_model():
    """
    Build the local chat model served through llama.cpp (ChatLlamaCpp).

    The model file is expected at ``settings.model_path / settings.local_model_name``.

    Returns:
        ChatLlamaCpp: Local model instance

    Raises:
        FileNotFoundError: If the model file does not exist on disk.
        Exception: Any other error raised while constructing the model.
            Every failure (including the missing-file case) is logged and
            re-raised unchanged.
    """
    try:
        gguf_path = settings.model_path / settings.local_model_name
        if not gguf_path.exists():
            raise FileNotFoundError(
                f"Model file not found: {gguf_path}\n"
                f"Please download it to {settings.model_path}/"
            )

        # Context / batching / threading configuration.
        runtime_options = {
            "n_ctx": 8096,        # context window size, in tokens
            "n_batch": 512,       # batch size for prompt processing
            "n_threads": 4,       # number of CPU threads
            "verbose": True,
            "chat_format": "chatml",
        }

        # Generation / sampling configuration.
        sampling_options = {
            "max_tokens": settings.local_max_tokens,  # cap on generated tokens
            "temperature": 0.1,   # low temperature for focused output
            "top_p": 0.9,         # nucleus sampling
            "top_k": 30,          # top-k sampling
            "repeat_penalty": 1.05,
        }

        # Memory-related switches.
        # NOTE(review): confirm that the installed ChatLlamaCpp version
        # accepts `f16`, `numa` and `chat_format` as constructor arguments.
        memory_options = {
            "f16_kv": True,       # half-precision KV cache
            "f16": True,
            "numa": False,
            "use_mlock": False,   # do not lock the model in RAM
            "use_mmap": True,     # memory-map the model file
        }

        local_llm = ChatLlamaCpp(
            model_path=str(gguf_path),
            **runtime_options,
            **sampling_options,
            **memory_options,
        )
        logger.info(f"Loaded local model: {settings.local_model_name}")
        return local_llm
    except Exception as exc:
        # Log for visibility, then let the caller decide how to react.
        logger.error(f"Failed to load local model: {exc}")
        raise
def get_llm_model(provider: Optional[str] = None):
    """
    Get LLM model based on configuration with fallback support.

    The requested provider is tried first. If it fails to load and
    ``settings.enable_fallback`` is truthy, the other provider is tried.

    Args:
        provider: Override the default provider ("gemini" or "local").
            If None (or empty), uses settings.llm_provider. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        LLM model instance (Gemini or Local)

    Raises:
        ValueError: If the provider is neither "gemini" nor "local".
        Exception: The primary loader's error when fallback is disabled,
            or the fallback loader's error when both providers fail.
    """
    requested = provider or settings.llm_provider
    # Normalise only real strings; anything else falls through to ValueError.
    key = requested.strip().lower() if isinstance(requested, str) else requested

    if key == "gemini":
        primary_loader, fallback_loader = get_gemini_model, get_local_model
        primary_label, fallback_label = "Gemini", "local"
    elif key == "local":
        primary_loader, fallback_loader = get_local_model, get_gemini_model
        primary_label, fallback_label = "Local", "Gemini"
    else:
        raise ValueError(f"Unknown provider: {requested}. Use 'gemini' or 'local'")

    # Logged (not printed) and worded as a selection: nothing is loaded yet.
    logger.info(f"Selected LLM provider: {key}")

    try:
        return primary_loader()
    except Exception as e:
        logger.warning(f"{primary_label} model failed: {e}")
        if not settings.enable_fallback:
            raise
        logger.info(f"Falling back to {fallback_label} model...")
        # A failure here propagates with the primary error as its context.
        return fallback_loader()