File size: 5,931 Bytes
4225666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Model factory for creating LLM and embedding models.
Handles model switching and fallback logic.
"""
from typing import Optional
from pathlib import Path
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.chat_models import ChatLlamaCpp
from app.core.config import settings
import logging

# Module-level logger named after this module; every factory function in this
# file reports load successes, failures and fallbacks through it.
logger = logging.getLogger(__name__)


def get_embedding_model():
    """
    Get the embedding model (currently only Gemini / Google Generative AI).

    The model name and API key are read from ``settings.embedding_model_name``
    and ``settings.google_api_key``.

    Returns:
        GoogleGenerativeAIEmbeddings: Embedding model instance

    Raises:
        Exception: Any error raised while constructing the embeddings client
            is logged (with traceback) and re-raised unchanged.
    """
    try:
        embeddings = GoogleGenerativeAIEmbeddings(
            model=settings.embedding_model_name,
            google_api_key=settings.google_api_key,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Failed to load embedding model: %s", e)
        raise
    logger.info("Loaded embedding model: %s", settings.embedding_model_name)
    return embeddings


def get_gemini_model():
    """
    Get the Google Gemini chat model.

    The model name and API key are read from ``settings.gemini_model_name``
    and ``settings.google_api_key``.

    Returns:
        ChatGoogleGenerativeAI: Gemini model instance

    Raises:
        Exception: Any error raised while constructing the chat client is
            logged (with traceback) and re-raised unchanged.
    """
    try:
        model = ChatGoogleGenerativeAI(
            model=settings.gemini_model_name,
            google_api_key=settings.google_api_key,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Failed to load Gemini model: %s", e)
        raise
    logger.info("Loaded Gemini model: %s", settings.gemini_model_name)
    return model


def get_local_model():
    """
    Get the local Qwen chat model served through llama.cpp (ChatLlamaCpp).

    The model file is resolved as ``settings.model_path / settings.local_model_name``
    and must already exist on disk; this function does not download anything.

    Returns:
        ChatLlamaCpp: Local model instance

    Raises:
        FileNotFoundError: If the model file does not exist.
        Exception: Any error raised by ChatLlamaCpp while loading the model.
            All errors are logged (with traceback) and re-raised unchanged.
    """
    try:
        # Path() accepts both str and Path, so this works however
        # settings.model_path happens to be declared.
        model_file = Path(settings.model_path) / settings.local_model_name

        if not model_file.exists():
            raise FileNotFoundError(
                f"Model file not found: {model_file}\n"
                f"Please download it to {settings.model_path}/"
            )

        model = ChatLlamaCpp(
            model_path=str(model_file),
            # Context window in tokens. NOTE(review): 8096 is not a power of two
            # and may be a typo for 8192 (an earlier config used 4096); confirm.
            n_ctx=8096,
            n_batch=512,          # Prompt-processing batch size
            n_threads=4,          # CPU threads used for inference
            max_tokens=settings.local_max_tokens,  # Cap on generated tokens
            temperature=0.1,      # Low temperature: focused, less random output
            top_p=0.9,            # Nucleus sampling
            top_k=30,             # Top-k sampling
            repeat_penalty=1.05,  # Mild penalty on repeated tokens
            f16_kv=True,          # Half-precision KV cache
            # NOTE(review): confirm the installed llama-cpp-python version
            # actually honours this option.
            f16=True,
            verbose=True,         # Emit llama.cpp load/inference logs
            # Prompt template; presumably matches the Qwen model's ChatML-style
            # chat format — verify against the model card.
            chat_format="chatml",
            numa=False,           # NUMA optimisations disabled
            use_mlock=False,      # Do not pin model pages in RAM
            use_mmap=True,        # Memory-map the model file instead of reading it fully
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Failed to load local model: %s", e)
        raise
    logger.info("Loaded local model: %s", settings.local_model_name)
    return model


def get_llm_model(provider: Optional[str] = None):
    """
    Get an LLM chat model by provider, optionally falling back to the other one.

    Args:
        provider: Override the default provider ("gemini" or "local").
                  Matching ignores case and surrounding whitespace.
                  If None (or empty), uses settings.llm_provider.

    Returns:
        LLM model instance (Gemini or Local)

    Raises:
        ValueError: If the provider is neither "gemini" nor "local".
        Exception: If the requested model fails and ``settings.enable_fallback``
            is false, the original loader error is re-raised. If fallback is
            enabled and the fallback model also fails, the fallback loader's
            error propagates (the first error is attached as its ``__context__``).
    """
    provider = provider or settings.llm_provider
    key = provider.strip().lower() if isinstance(provider, str) else provider

    # provider -> (primary loader, fallback loader, primary label, fallback label)
    # Labels reproduce the exact wording of the log messages.
    loaders = {
        "gemini": (get_gemini_model, get_local_model, "Gemini", "local"),
        "local": (get_local_model, get_gemini_model, "Local", "Gemini"),
    }
    if key not in loaders:
        raise ValueError(f"Unknown provider: {provider}. Use 'gemini' or 'local'")

    primary, fallback, primary_label, fallback_label = loaders[key]
    # Log (instead of print) which provider is requested; the actual load
    # result is logged by the loader itself.
    logger.info("Using LLM provider: %s", key)
    try:
        return primary()
    except Exception as e:
        logger.warning("%s model failed: %s", primary_label, e)
        if not settings.enable_fallback:
            raise
        logger.info("Falling back to %s model...", fallback_label)
        return fallback()