Spaces:
Running
Running
| """ | |
| Model factory for creating LLM and embedding models. | |
| Handles model switching and fallback logic. | |
| """ | |
| from typing import Optional | |
| from pathlib import Path | |
| from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings | |
| from langchain_community.chat_models import ChatLlamaCpp | |
| from app.core.config import settings | |
| import logging | |
| logger = logging.getLogger(__name__) | |
def get_embedding_model():
    """
    Build the Gemini embedding client described by the application settings.

    Returns:
        GoogleGenerativeAIEmbeddings: Client configured with
        ``settings.embedding_model_name`` and ``settings.google_api_key``.

    Raises:
        Exception: Any error raised while reading the settings or
        constructing the client is logged and re-raised unchanged.
    """
    try:
        model_name = settings.embedding_model_name
        embedder = GoogleGenerativeAIEmbeddings(
            model=model_name,
            google_api_key=settings.google_api_key,
        )
        logger.info(f"Loaded embedding model: {settings.embedding_model_name}")
    except Exception as err:
        # Record the failure here, but let the caller decide how to react.
        logger.error(f"Failed to load embedding model: {err}")
        raise
    return embedder
def get_gemini_model():
    """
    Build the Google Gemini chat model described by the application settings.

    Returns:
        ChatGoogleGenerativeAI: Chat model configured with
        ``settings.gemini_model_name`` and ``settings.google_api_key``.

    Raises:
        Exception: Any error raised while reading the settings or
        constructing the model is logged and re-raised unchanged.
    """
    try:
        gemini_options = {
            "model": settings.gemini_model_name,
            "google_api_key": settings.google_api_key,
        }
        chat_model = ChatGoogleGenerativeAI(**gemini_options)
        logger.info(f"Loaded Gemini model: {settings.gemini_model_name}")
    except Exception as err:
        # Record the failure here, but let the caller decide how to react.
        logger.error(f"Failed to load Gemini model: {err}")
        raise
    return chat_model
def get_local_model():
    """
    Load the local llama.cpp chat model referenced by the application settings.

    The model file is expected at
    ``settings.model_path / settings.local_model_name``.

    Returns:
        ChatLlamaCpp: Local chat model instance.

    Raises:
        FileNotFoundError: If the model file does not exist.
        Exception: Any other error raised while constructing the model.
        Every failure (including the missing file) is logged and re-raised
        unchanged.
    """
    try:
        weights_path = settings.model_path / settings.local_model_name
        if not weights_path.exists():
            raise FileNotFoundError(
                f"Model file not found: {weights_path}\n"
                f"Please download it to {settings.model_path}/"
            )

        llama_options = {
            "model_path": str(weights_path),
            # NOTE(review): 8096 is not a power of two; 8192 may have been
            # intended -- confirm before changing.
            "n_ctx": 8096,
            "n_batch": 512,
            "n_threads": 4,
            "max_tokens": settings.local_max_tokens,
            # Sampling parameters.
            "temperature": 0.1,
            "top_p": 0.9,
            "top_k": 30,
            "repeat_penalty": 1.05,
            "f16_kv": True,
            # NOTE(review): `f16`, `chat_format` and `numa` do not appear to
            # be declared ChatLlamaCpp fields; verify they actually reach
            # llama.cpp (they may need to go through `model_kwargs`).
            "f16": True,
            "verbose": True,
            "chat_format": "chatml",
            "numa": False,
            # Memory-map the model file and do not lock it into RAM.
            "use_mlock": False,
            "use_mmap": True,
        }
        local_llm = ChatLlamaCpp(**llama_options)
        logger.info(f"Loaded local model: {settings.local_model_name}")
    except Exception as err:
        # Record the failure here, but let the caller decide how to react.
        logger.error(f"Failed to load local model: {err}")
        raise
    return local_llm
def get_llm_model(provider: Optional[str] = None):
    """
    Return a chat model for the requested provider, with optional fallback.

    Args:
        provider: Override for the default provider ("gemini" or "local").
            If None (or empty), ``settings.llm_provider`` is used.

    Returns:
        The model returned by ``get_gemini_model()`` or ``get_local_model()``.

    Raises:
        ValueError: If the resolved provider is neither "gemini" nor "local".
        Exception: If the primary loader fails and
            ``settings.enable_fallback`` is falsy, the primary error is
            re-raised. If the fallback loader fails as well, the fallback's
            error propagates (the primary error is kept as its context).
    """
    provider = provider or settings.llm_provider

    # Reject unknown providers before touching any loader.
    if provider not in ("gemini", "local"):
        raise ValueError(f"Unknown provider: {provider}. Use 'gemini' or 'local'")

    # Both providers share the same try-primary-then-fallback flow; only the
    # loader order and the labels used in log messages differ.
    if provider == "gemini":
        primary_label, load_primary = "Gemini", get_gemini_model
        fallback_label, load_fallback = "local", get_local_model
    else:
        primary_label, load_primary = "Local", get_local_model
        fallback_label, load_fallback = "Gemini", get_gemini_model

    # Logged before loading, so the message says "selected", not "loaded".
    logger.info(f"Selected LLM provider: {provider}")

    try:
        return load_primary()
    except Exception as e:
        logger.warning(f"{primary_label} model failed: {e}")
        if not settings.enable_fallback:
            raise
        logger.info(f"Falling back to {fallback_label} model...")
        return load_fallback()