"""LLM utility functions using Groq API (OpenAI-compatible)."""

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI

from src.config import settings
from src.utils.logging import log_pipeline

_model_cache: dict[str, BaseChatModel] = {}
_override_large_model: str | None = None  # Global override for large model


def set_large_model_override(model_name: str | None = None) -> None:
    """Set (or clear, by passing None) the global override for the large model."""
    global _override_large_model
    _override_large_model = model_name


def get_small_model() -> BaseChatModel:
    """Get or create small LLM (for routing, reranking, RAG) - using Groq."""
    cache_key = "small"
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    model = ChatOpenAI(
        model=settings.model_small,
        api_key=settings.groq_api_key,
        base_url=settings.groq_base_url,
        temperature=0.6,
        max_tokens=4096,
    )
    
    _model_cache[cache_key] = model
    log_pipeline(f"[Model] Small model initialized: {settings.model_small} (Groq)")
    return model


def get_large_model(model_name: str | None = None) -> BaseChatModel:
    """Get or create large LLM (for logic/direct answering) - using Groq.
    
    Args:
        model_name: Optional model name to override default
    """
    # Use override if set, otherwise use parameter or default
    effective_model = _override_large_model or model_name or settings.model_large
    cache_key = f"large_{effective_model}"
    
    # Always recreate if override is set (for Streamlit model switching)
    if _override_large_model and cache_key in _model_cache:
        del _model_cache[cache_key]
    
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    # Use higher temperature for GPT-OSS-120B to encourage reasoning
    temperature = 0.3 if "gpt-oss-120b" in effective_model.lower() else 0.0
    
    model = ChatOpenAI(
        model=effective_model,
        api_key=settings.groq_api_key,
        base_url=settings.groq_base_url,
        temperature=temperature,
        max_tokens=2048,
    )
    
    _model_cache[cache_key] = model
    log_pipeline(f"[Model] Large model initialized: {effective_model} (Groq, temp={temperature})")
    return model


def get_available_large_models() -> list[str]:
    """Get list of available large models for testing."""
    return settings.available_large_models
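

# Example usage (a minimal sketch, not part of the module's public API).
# Assumes this module is importable as src.utils.llm and that Groq credentials
# (groq_api_key, groq_base_url) and model names are configured in
# src.config.settings; the model name passed to the override below is
# purely illustrative.
#
#     from src.utils.llm import get_large_model, get_small_model, set_large_model_override
#
#     small = get_small_model()   # created once, then served from the cache
#     large = get_large_model()   # defaults to settings.model_large
#
#     # Switch the large model globally (e.g. from a Streamlit selector);
#     # subsequent get_large_model() calls rebuild the client with this model.
#     set_large_model_override("example-large-model")
#     large = get_large_model()
#
#     # Clear the override to fall back to settings.model_large.
#     set_large_model_override(None)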