"""
Model factory for creating LLM and embedding models.

Handles model switching and fallback logic.
"""
import logging
from pathlib import Path
from typing import Optional, Union

from langchain_community.chat_models import ChatLlamaCpp
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

from app.core.config import settings

logger = logging.getLogger(__name__)


def get_embedding_model() -> GoogleGenerativeAIEmbeddings:
    """
    Get the embedding model (currently only Gemini).

    The model name and API key are read from ``settings.embedding_model_name``
    and ``settings.google_api_key``.

    Returns:
        GoogleGenerativeAIEmbeddings: Embedding model instance

    Raises:
        Exception: Any error raised while constructing the model is logged
            and re-raised unchanged.
    """
    try:
        embeddings = GoogleGenerativeAIEmbeddings(
            model=settings.embedding_model_name,
            google_api_key=settings.google_api_key,
        )
    except Exception as e:
        logger.error("Failed to load embedding model: %s", e)
        raise

    logger.info("Loaded embedding model: %s", settings.embedding_model_name)
    return embeddings


def get_gemini_model() -> ChatGoogleGenerativeAI:
    """
    Get Google Gemini chat model.

    The model name and API key are read from ``settings.gemini_model_name``
    and ``settings.google_api_key``.

    Returns:
        ChatGoogleGenerativeAI: Gemini model instance

    Raises:
        Exception: Any error raised while constructing the model is logged
            and re-raised unchanged.
    """
    try:
        model = ChatGoogleGenerativeAI(
            model=settings.gemini_model_name,
            google_api_key=settings.google_api_key,
        )
    except Exception as e:
        logger.error("Failed to load Gemini model: %s", e)
        raise

    logger.info("Loaded Gemini model: %s", settings.gemini_model_name)
    return model


def get_local_model() -> ChatLlamaCpp:
    """
    Get local Qwen model (LlamaCpp).

    The model file is expected at
    ``settings.model_path / settings.local_model_name``.

    Returns:
        ChatLlamaCpp: Local model instance

    Raises:
        FileNotFoundError: If the model file does not exist.
        Exception: Any other error raised while constructing the model.
            All failures are logged and re-raised unchanged.
    """
    try:
        # Path() makes this work whether the setting is a str or a Path.
        model_file = Path(settings.model_path) / settings.local_model_name
        if not model_file.exists():
            raise FileNotFoundError(
                f"Model file not found: {model_file}\n"
                f"Please download it to {settings.model_path}/"
            )

        # Tuned for a low-memory, CPU-only host.
        # NOTE(review): n_ctx=8096 is not a power of two; 8192 may have been
        # intended - confirm before changing.
        # NOTE(review): `f16`, `chat_format` and `numa` should be checked
        # against the fields ChatLlamaCpp declares in the installed
        # langchain_community version - confirm they take effect.
        model = ChatLlamaCpp(
            model_path=str(model_file),
            n_ctx=8096,  # Context window size
            n_batch=512,  # Batch size for prompt processing
            n_threads=4,  # Number of CPU threads
            max_tokens=settings.local_max_tokens,  # Maximum tokens to generate
            temperature=0.1,  # Focused output, less randomness
            top_p=0.9,
            top_k=30,
            repeat_penalty=1.05,
            f16_kv=True,  # Half-precision KV cache
            f16=True,
            verbose=True,
            chat_format="chatml",
            numa=False,
            use_mlock=False,  # Do not lock the model in RAM
            use_mmap=True,  # Memory-map the model file
        )
    except Exception as e:
        logger.error("Failed to load local model: %s", e)
        raise

    logger.info("Loaded local model: %s", settings.local_model_name)
    return model


def get_llm_model(
    provider: Optional[str] = None,
) -> Union[ChatGoogleGenerativeAI, ChatLlamaCpp]:
    """
    Get LLM model based on configuration with fallback support.

    Args:
        provider: Override the default provider ("gemini" or "local").
            If None (or empty), uses settings.llm_provider. String values
            are matched case-insensitively, ignoring surrounding whitespace.

    Returns:
        LLM model instance (Gemini or Local)

    Raises:
        ValueError: If the provider is neither "gemini" nor "local".
        Exception: If the selected model fails to load and
            ``settings.enable_fallback`` is falsy, the original error is
            re-raised. If fallback is enabled and the fallback model also
            fails, the fallback loader's error propagates.
    """
    requested = provider or settings.llm_provider
    key = requested.strip().lower() if isinstance(requested, str) else None

    # Built per call so the loaders are looked up by name at call time.
    # Each entry: (primary loader, primary label, fallback loader, fallback label)
    routes = {
        "gemini": (get_gemini_model, "Gemini", get_local_model, "local"),
        "local": (get_local_model, "Local", get_gemini_model, "Gemini"),
    }
    route = routes.get(key) if key is not None else None
    if route is None:
        raise ValueError(f"Unknown provider: {requested}. Use 'gemini' or 'local'")

    load_primary, primary_label, load_fallback, fallback_label = route
    logger.info("Selected LLM provider: %s", key)

    try:
        return load_primary()
    except Exception as e:
        logger.warning("%s model failed: %s", primary_label, e)
        if settings.enable_fallback:
            logger.info("Falling back to %s model...", fallback_label)
            return load_fallback()
        raise