File size: 3,377 Bytes
9af190b
 
0bfc688
affeafa
9af190b
090f2a4
 
0bfc688
affeafa
 
090f2a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfc688
 
 
 
 
 
 
 
 
 
3d758c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfc688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9af190b
090f2a4
 
affeafa
9af190b
 
affeafa
9af190b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langgraph.graph import StateGraph, MessagesState, START, END
from livekit.plugins import nvidia

from src.core.settings import settings
from src.core.logger import logger
from src.plugins.huggingface_llm import HuggingFaceLLM
from src.plugins.moonshine_stt import MoonshineSTT


def create_llm():
    """Build the chat model selected by ``settings.llm.LLM_PROVIDER``.

    Returns:
        A ``ChatNVIDIA`` instance for the ``"nvidia"`` provider, or a
        ``HuggingFaceLLM`` instance for the ``"huggingface"`` provider.

    Raises:
        ValueError: If the configured provider is neither of the above.
    """
    cfg = settings.llm
    provider = cfg.LLM_PROVIDER.lower()

    if provider == "nvidia":
        logger.info(f"Initializing NVIDIA LLM: {settings.llm.NVIDIA_MODEL}")
        return ChatNVIDIA(
            model=cfg.NVIDIA_MODEL,
            api_key=cfg.NVIDIA_API_KEY,
            temperature=cfg.LLM_TEMPERATURE,
            max_tokens=cfg.LLM_MAX_TOKENS,
        )

    if provider == "huggingface":
        logger.info(f"Initializing HuggingFace LLM: {settings.llm.HUGGINGFACE_MODEL_ID}")
        return HuggingFaceLLM(
            model_id=cfg.HUGGINGFACE_MODEL_ID,
            device=cfg.HUGGINGFACE_DEVICE,
            temperature=cfg.LLM_TEMPERATURE,
            max_tokens=cfg.LLM_MAX_TOKENS,
            top_p=0.95,
            repetition_penalty=1.0,
        )

    raise ValueError(f"Unknown LLM provider: {provider}. Must be 'nvidia' or 'huggingface'")


def create_stt():
    """Build the speech-to-text engine selected by ``settings.stt.STT_PROVIDER``.

    For the ``"nvidia"`` provider, the API key is resolved from
    ``NVIDIA_STT_API_KEY`` first and the shared ``NVIDIA_API_KEY`` second;
    a warning is logged when neither is set.

    Returns:
        An ``nvidia.STT`` instance for ``"nvidia"``, or a ``MoonshineSTT``
        instance for ``"moonshine"``.

    Raises:
        ValueError: If the configured provider is neither of the above.
    """
    provider = settings.stt.STT_PROVIDER.lower()

    if provider == "nvidia":
        logger.info(
            f"Initializing NVIDIA STT: {settings.stt.NVIDIA_STT_MODEL} "
            f"(language: {settings.stt.NVIDIA_STT_LANGUAGE_CODE})"
        )
        # Prefer the STT-specific key; fall back to the shared LLM key.
        if settings.stt.NVIDIA_STT_API_KEY:
            api_key, key_source = settings.stt.NVIDIA_STT_API_KEY, "NVIDIA_STT_API_KEY"
        elif settings.llm.NVIDIA_API_KEY:
            api_key, key_source = settings.llm.NVIDIA_API_KEY, "NVIDIA_API_KEY"
        else:
            api_key, key_source = None, "not_set"

        logger.info("NVIDIA STT auth source: %s", key_source)
        if not api_key:
            logger.warning(
                "NVIDIA STT is configured but no API key is set (NVIDIA_STT_API_KEY/NVIDIA_API_KEY)"
            )

        return nvidia.STT(
            language_code=settings.stt.NVIDIA_STT_LANGUAGE_CODE,
            model=settings.stt.NVIDIA_STT_MODEL,
            api_key=api_key,
        )

    if provider == "moonshine":
        logger.info(
            f"Initializing Moonshine STT: {settings.stt.MOONSHINE_MODEL_ID} "
            f"(language: {settings.stt.MOONSHINE_LANGUAGE})"
        )
        return MoonshineSTT(
            model_id=settings.stt.MOONSHINE_MODEL_ID,
            language=settings.stt.MOONSHINE_LANGUAGE,
        )

    raise ValueError(
        f"Unknown STT provider: {provider}. Must be 'nvidia' or 'moonshine'"
    )


def create_graph():
    """Build and compile a one-node LangGraph workflow around the configured LLM.

    The graph has a single "agent" node that forwards the accumulated
    messages to the LLM and appends its reply to the state.
    """
    model = create_llm()

    def _invoke_agent(state: MessagesState) -> dict:
        # Hand the full message history to the model; append its response.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    graph = StateGraph(MessagesState)
    graph.add_node("agent", _invoke_agent)
    graph.add_edge(START, "agent")
    graph.add_edge("agent", END)
    return graph.compile()