agent-project / core /models.py
ego
fix: improve graph generation reliability (strip thinking blocks, add retry, robust DOT extraction)
1fe2fca
import os
import streamlit as st
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
def get_llm(model_name: str = "nvidia/nemotron-3-nano-30b-a3b"):
    """Build a ``ChatNVIDIA`` chat client for the given model.

    The API key is taken from the ``NV_API_KEY`` environment variable. When
    that is unset or empty, ``st.secrets["NV_API_KEY"]`` is used if the key is
    present there.

    Args:
        model_name: Model identifier forwarded to ``ChatNVIDIA``.

    Returns:
        A ``ChatNVIDIA`` instance configured with temperature 0, seed 42,
        a 16384-token output cap and ``enable_thinking`` set in the chat
        template kwargs.

    Raises:
        ValueError: If no non-empty key is found in either location.
    """
    nv_key = os.getenv("NV_API_KEY")
    if not nv_key:
        # Environment variable missing or empty: fall back to Streamlit secrets.
        if "NV_API_KEY" in st.secrets:
            nv_key = st.secrets["NV_API_KEY"]
        if not nv_key:
            raise ValueError("NVIDIA API Key not found in environment or secrets.")

    client_options = {
        "model": model_name,
        "temperature": 0,
        "seed": 42,
        "max_tokens": 16384,
        # Forwarded verbatim in the request body.
        # NOTE(review): presumably enables the model's reasoning output - confirm.
        "extra_body": {"chat_template_kwargs": {"enable_thinking": True}},
        "api_key": nv_key,
    }
    return ChatNVIDIA(**client_options)
def get_embeddings():
    """Build an ``NVIDIAEmbeddings`` client for ``nvidia/llama-nemotron-embed-1b-v2``.

    The API key is taken from the ``NV_API_KEY`` environment variable. When
    that is unset or empty, ``st.secrets["NV_API_KEY"]`` is used if the key is
    present there.

    Returns:
        An ``NVIDIAEmbeddings`` instance bound to the resolved API key.

    Raises:
        ValueError: If no non-empty key is found in either location.
    """
    secret = os.getenv("NV_API_KEY")
    # Only consult Streamlit secrets when the environment did not supply a key.
    needs_fallback = (not secret) and ("NV_API_KEY" in st.secrets)
    if needs_fallback:
        secret = st.secrets["NV_API_KEY"]

    if secret:
        return NVIDIAEmbeddings(
            model="nvidia/llama-nemotron-embed-1b-v2",
            api_key=secret,
        )
    raise ValueError("NV_API_KEY not found in environment or secrets.")
def _strip_speaker_labels(script_text: str) -> str:
    """Flatten a podcast script into a single line with speaker labels removed."""
    spoken = []
    for raw_line in script_text.split("\n"):
        line = raw_line.strip()
        # Remove speaker labels (Alex:/Jamie:)
        if line.startswith(("Alex:", "Jamie:")):
            line = line.partition(":")[2].strip()
        # Skip blank lines and lines that held nothing but a label.
        if line:
            spoken.append(line)
    return " ".join(spoken)


def generate_podcast_audio(script_text: str):
    """Generate podcast audio using NVIDIA Riva TTS hosted API (magpie-tts-multilingual).

    Single-voice synthesis: speaker labels (Alex:/Jamie:) are removed and the
    remaining lines are joined with spaces before being sent for synthesis.

    The API key is read from the ``NV_API_KEY`` environment variable, falling
    back to ``st.secrets["NV_API_KEY"]``. Every failure path (missing key,
    unreadable secrets, empty script, missing client library, service error)
    is reported with ``print`` and results in ``None`` rather than an exception.

    Args:
        script_text: Podcast script, one utterance per line, optionally
            prefixed with ``Alex:`` or ``Jamie:``.

    Returns:
        PCM audio bytes (16-bit, mono, 22050 Hz) or None if failed.
    """
    api_key = os.getenv("NV_API_KEY")
    if not api_key:
        try:
            if "NV_API_KEY" in st.secrets:
                api_key = st.secrets["NV_API_KEY"]
        except Exception as e:
            # Reading st.secrets can raise (e.g. when no secrets file is
            # configured); keep the "None on failure" contract instead of crashing.
            print(f"Could not read Streamlit secrets for TTS: {e}")
    if not api_key:
        print("NV_API_KEY not found for TTS")
        return None

    try:
        clean_text = _strip_speaker_labels(script_text)
        if not clean_text:
            # Nothing to say: avoid a pointless remote call.
            print("Podcast script is empty; nothing to synthesize")
            return None

        # Imported lazily so a missing client library is reported as a TTS
        # failure instead of breaking module import.
        from riva.client import Auth
        try:
            from riva.client import TTSService
        except ImportError:
            # NOTE(review): the NVIDIA Riva Python client documents this class
            # as SpeechSynthesisService; fall back to it when the TTSService
            # name is not exported.
            from riva.client import SpeechSynthesisService as TTSService

        # Setup authentication for NVIDIA hosted Riva TTS
        metadata = [
            ("function-id", "877104f7-e885-42b9-8de8-f6e4c6303969"),
            ("authorization", f"Bearer {api_key}"),
        ]
        auth = Auth(None, True, "grpc.nvcf.nvidia.com:443", metadata)
        tts_service = TTSService(auth)

        # Call TTS API - use keyword args matching gRPC SynthesizeSpeechRequest
        resp = tts_service.synthesize(
            text=clean_text,
            language_code="en-US",
            encoding=1,  # LINEAR_PCM
            sample_rate_hz=22050,
            voice_name="Magpie-Multilingual.EN-US.Aria",
        )
        # resp.audio contains PCM bytes (16-bit, mono)
        return resp.audio
    except Exception as e:
        # Best-effort boundary: log the failure and signal it with None.
        print(f"NVIDIA Riva TTS failed: {e}")
        import traceback
        traceback.print_exc()
        return None