from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langgraph.graph import StateGraph, MessagesState, START, END
from livekit.plugins import nvidia
from src.core.settings import settings
from src.core.logger import logger
from src.plugins.huggingface_llm import HuggingFaceLLM
from src.plugins.moonshine_stt import MoonshineSTT
def create_llm():
    """Build the chat model selected by ``settings.llm.LLM_PROVIDER``.

    Returns a ``ChatNVIDIA`` client when the provider is "nvidia", or a
    local ``HuggingFaceLLM`` when it is "huggingface".

    Raises:
        ValueError: If the configured provider is neither of the above.
    """
    chosen = settings.llm.LLM_PROVIDER.lower()

    if chosen == "nvidia":
        logger.info(f"Initializing NVIDIA LLM: {settings.llm.NVIDIA_MODEL}")
        return ChatNVIDIA(
            model=settings.llm.NVIDIA_MODEL,
            api_key=settings.llm.NVIDIA_API_KEY,
            max_tokens=settings.llm.LLM_MAX_TOKENS,
            temperature=settings.llm.LLM_TEMPERATURE,
        )

    if chosen == "huggingface":
        logger.info(f"Initializing HuggingFace LLM: {settings.llm.HUGGINGFACE_MODEL_ID}")
        # top_p / repetition_penalty are fixed sampling defaults here,
        # not driven by settings.
        return HuggingFaceLLM(
            model_id=settings.llm.HUGGINGFACE_MODEL_ID,
            device=settings.llm.HUGGINGFACE_DEVICE,
            temperature=settings.llm.LLM_TEMPERATURE,
            max_tokens=settings.llm.LLM_MAX_TOKENS,
            top_p=0.95,
            repetition_penalty=1.0,
        )

    raise ValueError(f"Unknown LLM provider: {chosen}. Must be 'nvidia' or 'huggingface'")
def create_stt():
    """Create an STT instance based on the configured provider.

    Returns:
        A speech-to-text plugin: ``nvidia.STT`` when
        ``settings.stt.STT_PROVIDER`` is "nvidia", ``MoonshineSTT`` when it
        is "moonshine" (case-insensitive).

    Raises:
        ValueError: If the configured provider is neither of the above.
    """
    provider = settings.stt.STT_PROVIDER.lower()
    if provider == "nvidia":
        # Lazy %-style args (consistent with the auth-source log below)
        # instead of eager f-string formatting inside logging calls.
        logger.info(
            "Initializing NVIDIA STT: %s (language: %s)",
            settings.stt.NVIDIA_STT_MODEL,
            settings.stt.NVIDIA_STT_LANGUAGE_CODE,
        )
        # Prefer the STT-specific key; fall back to the shared NVIDIA key.
        if settings.stt.NVIDIA_STT_API_KEY:
            api_key = settings.stt.NVIDIA_STT_API_KEY
            key_source = "NVIDIA_STT_API_KEY"
        elif settings.llm.NVIDIA_API_KEY:
            api_key = settings.llm.NVIDIA_API_KEY
            key_source = "NVIDIA_API_KEY"
        else:
            api_key = None
            key_source = "not_set"
        logger.info("NVIDIA STT auth source: %s", key_source)
        if not api_key:
            # Warn but proceed: the plugin is still constructed with
            # api_key=None rather than failing startup here.
            logger.warning(
                "NVIDIA STT is configured but no API key is set (NVIDIA_STT_API_KEY/NVIDIA_API_KEY)"
            )
        return nvidia.STT(
            language_code=settings.stt.NVIDIA_STT_LANGUAGE_CODE,
            model=settings.stt.NVIDIA_STT_MODEL,
            api_key=api_key,
        )
    elif provider == "moonshine":
        logger.info(
            "Initializing Moonshine STT: %s (language: %s)",
            settings.stt.MOONSHINE_MODEL_ID,
            settings.stt.MOONSHINE_LANGUAGE,
        )
        return MoonshineSTT(
            model_id=settings.stt.MOONSHINE_MODEL_ID,
            language=settings.stt.MOONSHINE_LANGUAGE,
        )
    else:
        raise ValueError(
            f"Unknown STT provider: {provider}. Must be 'nvidia' or 'moonshine'"
        )
def create_graph():
    """Compile a minimal LangGraph workflow: START -> "agent" -> END.

    The single "agent" node forwards the accumulated messages to the LLM
    chosen by :func:`create_llm` and appends the model's reply to state.
    """
    model = create_llm()

    def invoke_agent(state: MessagesState) -> dict:
        # One LLM round-trip over the full message history.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    graph = StateGraph(MessagesState)
    graph.add_node("agent", invoke_agent)
    graph.add_edge(START, "agent")
    graph.add_edge("agent", END)
    return graph.compile()