"""
MediGuard AI RAG-Helper
LLM configuration and initialization
Supports multiple providers:
- Groq (FREE, fast, llama-3.3-70b) - RECOMMENDED
- Google Gemini (FREE tier)
- Ollama (local, for offline use)
Environment Variables (supports both naming conventions):
- Simple: GROQ_API_KEY, GOOGLE_API_KEY, LLM_PROVIDER, GROQ_MODEL, etc.
- Nested: LLM__GROQ_API_KEY, LLM__GOOGLE_API_KEY, LLM__PROVIDER, etc.
"""
import os
import threading
from typing import Literal
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Configure LangSmith tracing
os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT", "MediGuard_AI_RAG_Helper")
def _get_env_with_fallback(primary: str, fallback: str, default: str = "") -> str:
"""Get env var with fallback to alternate naming convention."""
return os.getenv(primary) or os.getenv(fallback) or default
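# For example, with LLM_PROVIDER unset and LLM__PROVIDER=ollama set,
# _get_env_with_fallback("LLM_PROVIDER", "LLM__PROVIDER", "groq") returns
# "ollama"; with neither variable set it returns the default "groq".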
def get_default_llm_provider() -> str:
"""Get default LLM provider dynamically from environment.
Supports both naming conventions:
- LLM_PROVIDER (simple)
- LLM__PROVIDER (pydantic nested)
"""
return _get_env_with_fallback("LLM_PROVIDER", "LLM__PROVIDER", "groq")
def get_groq_api_key() -> str:
"""Get Groq API key from environment (supports both naming conventions)."""
return _get_env_with_fallback("GROQ_API_KEY", "LLM__GROQ_API_KEY", "")
def get_google_api_key() -> str:
"""Get Google API key from environment (supports both naming conventions)."""
return _get_env_with_fallback("GOOGLE_API_KEY", "LLM__GOOGLE_API_KEY", "")
def get_groq_model() -> str:
"""Get Groq model from environment (supports both naming conventions)."""
return _get_env_with_fallback("GROQ_MODEL", "LLM__GROQ_MODEL", "llama-3.3-70b-versatile")
def get_gemini_model() -> str:
"""Get Gemini model from environment (supports both naming conventions)."""
return _get_env_with_fallback("GEMINI_MODEL", "LLM__GEMINI_MODEL", "gemini-2.0-flash")
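# For example, exporting GROQ_MODEL=llama-3.1-8b-instant (or the nested
# LLM__GROQ_MODEL form) overrides the default returned by get_groq_model();
# the model name here is purely illustrative.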
# Snapshot taken at import time, kept for backward compatibility;
# prefer get_default_llm_provider() so runtime environment changes are picked up.
DEFAULT_LLM_PROVIDER = get_default_llm_provider()
def get_chat_model(
provider: Literal["groq", "gemini", "ollama"] | None = None,
model: str | None = None,
temperature: float = 0.0,
json_mode: bool = False,
):
"""
Get a chat model from the specified provider.
Args:
provider: "groq" (free, fast), "gemini" (free), or "ollama" (local)
model: Model name (provider-specific)
temperature: Sampling temperature
json_mode: Whether to enable JSON output mode
Returns:
LangChain chat model instance
"""
# Use dynamic lookup to get current provider from environment
provider = provider or get_default_llm_provider()
if provider == "groq":
from langchain_groq import ChatGroq
api_key = get_groq_api_key()
if not api_key:
raise ValueError(
"GROQ_API_KEY not found in environment.\nGet your FREE API key at: https://console.groq.com/keys"
)
# Use model from environment or default
model = model or get_groq_model()
return ChatGroq(
model=model,
temperature=temperature,
api_key=api_key,
model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {},
)
elif provider == "gemini":
from langchain_google_genai import ChatGoogleGenerativeAI
api_key = get_google_api_key()
if not api_key:
raise ValueError(
"GOOGLE_API_KEY not found in environment.\n"
"Get your FREE API key at: https://aistudio.google.com/app/apikey"
)
# Use model from environment or default
model = model or get_gemini_model()
return ChatGoogleGenerativeAI(
model=model, temperature=temperature, google_api_key=api_key, convert_system_message_to_human=True
)
elif provider == "ollama":
try:
from langchain_ollama import ChatOllama
except ImportError:
from langchain_community.chat_models import ChatOllama
model = model or "llama3.1:8b"
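        # Note: ChatOllama talks to a locally running Ollama server (default
        # http://localhost:11434); the model must be pulled beforehand, e.g.
        # `ollama pull llama3.1:8b`.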
return ChatOllama(model=model, temperature=temperature, format="json" if json_mode else None)
else:
raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
def get_embedding_provider() -> str:
"""Get embedding provider from environment (supports both naming conventions)."""
return _get_env_with_fallback("EMBEDDING_PROVIDER", "EMBEDDING__PROVIDER", "huggingface")
def get_embedding_model(provider: Literal["jina", "google", "huggingface", "ollama"] | None = None):
"""
Get embedding model for vector search.
Args:
provider: "jina" (high-quality), "google" (free), "huggingface" (local), or "ollama" (local)
Returns:
LangChain embedding model instance
Note:
For production use, prefer src.services.embeddings.service.make_embedding_service()
which has automatic fallback chain: Jina → Google → HuggingFace.
"""
provider = provider or get_embedding_provider()
if provider == "jina":
# Try Jina AI embeddings first (high quality, 1024d)
jina_key = _get_env_with_fallback("JINA_API_KEY", "EMBEDDING__JINA_API_KEY", "")
if jina_key:
try:
# Use the embedding service for Jina
from src.services.embeddings.service import make_embedding_service
return make_embedding_service()
except Exception as e:
print(f"WARN: Jina embeddings failed: {e}")
print("INFO: Falling back to Google embeddings...")
return get_embedding_model("google")
else:
print("WARN: JINA_API_KEY not found. Falling back to Google embeddings.")
return get_embedding_model("google")
elif provider == "google":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
api_key = get_google_api_key()
if not api_key:
print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
return get_embedding_model("huggingface")
try:
return GoogleGenerativeAIEmbeddings(model="models/text-embedding-004", google_api_key=api_key)
except Exception as e:
print(f"WARN: Google embeddings failed: {e}")
print("INFO: Falling back to HuggingFace embeddings...")
return get_embedding_model("huggingface")
elif provider == "huggingface":
try:
from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
from langchain_community.embeddings import HuggingFaceEmbeddings
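        # The sentence-transformers model is downloaded from the Hugging Face
        # Hub on first use and cached; afterwards embedding runs offline on
        # CPU, which is why this is the final fallback in the chain.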
return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
elif provider == "ollama":
try:
from langchain_ollama import OllamaEmbeddings
except ImportError:
from langchain_community.embeddings import OllamaEmbeddings
return OllamaEmbeddings(model="nomic-embed-text")
else:
raise ValueError(f"Unknown embedding provider: {provider}")
class LLMConfig:
"""Central configuration for all LLM models"""
def __init__(self, provider: str | None = None, lazy: bool = True):
"""
Initialize all model clients.
Args:
provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local)
lazy: If True, defer model initialization until first use (avoids API key errors at import)
"""
# Store explicit provider or None to use dynamic lookup later
self._explicit_provider = provider
self._lazy = lazy
self._initialized = False
        self._initialized_provider: str | None = None  # Track which provider was initialized
self._lock = threading.Lock()
# Lazy-initialized model instances
self._planner = None
self._analyzer = None
self._explainer = None
self._synthesizer_7b = None
self._synthesizer_8b = None
self._director = None
self._embedding_model = None
if not lazy:
self._initialize_models()
@property
def provider(self) -> str:
"""Get current provider (dynamic lookup if not explicitly set)."""
return self._explicit_provider or get_default_llm_provider()
def _check_provider_change(self):
"""Check if provider changed and reinitialize if needed."""
current = self.provider
if self._initialized and self._initialized_provider != current:
print(f"Provider changed from {self._initialized_provider} to {current}, reinitializing...")
self._initialized = False
self._planner = None
self._analyzer = None
self._explainer = None
self._synthesizer_7b = None
self._synthesizer_8b = None
self._director = None
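    # For example, setting os.environ["LLM_PROVIDER"] = "gemini" at runtime
    # makes the next property access on a config created without an explicit
    # provider drop its cached models and re-initialize with Gemini.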
def _initialize_models(self):
"""Initialize all model clients (called on first use if lazy)"""
self._check_provider_change()
if self._initialized:
return
with self._lock:
# Double-checked locking
if self._initialized:
return
print(f"Initializing LLM models with provider: {self.provider.upper()}")
# Fast model for structured tasks (planning, analysis)
self._planner = get_chat_model(provider=self.provider, temperature=0.0, json_mode=True)
# Fast model for biomarker analysis and quick tasks
self._analyzer = get_chat_model(provider=self.provider, temperature=0.0)
# Medium model for RAG retrieval and explanation
self._explainer = get_chat_model(provider=self.provider, temperature=0.2)
# Configurable synthesizers
self._synthesizer_7b = get_chat_model(provider=self.provider, temperature=0.2)
self._synthesizer_8b = get_chat_model(provider=self.provider, temperature=0.2)
# Director for Outer Loop
self._director = get_chat_model(provider=self.provider, temperature=0.0, json_mode=True)
# Embedding model for RAG
self._embedding_model = get_embedding_model()
self._initialized = True
self._initialized_provider = self.provider
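    # Note: every role above shares the provider's default model and differs
    # only in temperature and JSON mode; pass model=... to get_chat_model()
    # to give a role its own model.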
@property
def planner(self):
self._initialize_models()
return self._planner
@property
def analyzer(self):
self._initialize_models()
return self._analyzer
@property
def explainer(self):
self._initialize_models()
return self._explainer
@property
def synthesizer_7b(self):
self._initialize_models()
return self._synthesizer_7b
@property
def synthesizer_8b(self):
self._initialize_models()
return self._synthesizer_8b
@property
def director(self):
self._initialize_models()
return self._director
@property
def embedding_model(self):
self._initialize_models()
return self._embedding_model
def get_synthesizer(self, model_name: str | None = None):
"""Get synthesizer model (for backward compatibility)"""
if model_name:
return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
return self.synthesizer_8b
def print_config(self):
"""Print current LLM configuration"""
print("=" * 60)
print("MediGuard AI RAG-Helper - LLM Configuration")
print("=" * 60)
print(f"Provider: {self.provider.upper()}")
        if self.provider == "groq":
            print(f"Model: {get_groq_model()} (FREE)")
        elif self.provider == "gemini":
            print(f"Model: {get_gemini_model()} (FREE)")
        else:
            print("Model: llama3.1:8b (local)")
        print(f"Embeddings: {get_embedding_provider()}")
print("=" * 60)
# Global LLM configuration instance
llm_config = LLMConfig()
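# Example (illustrative; assumes this module is importable as src.llm_config):
#     from src.llm_config import llm_config
#     answer = llm_config.explainer.invoke("Explain LDL cholesterol briefly.")
# Models are created lazily, so importing this module needs no API keys.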
def get_synthesizer(model_name: str | None = None):
"""Module-level convenience: get a synthesizer LLM instance."""
return llm_config.get_synthesizer(model_name)
def check_api_connection():
"""Verify API connection and keys are configured"""
    provider = get_default_llm_provider()
    try:
        if provider == "groq":
            api_key = get_groq_api_key()
            if not api_key:
                print("WARN: GROQ_API_KEY not set")
                print("\n Get your FREE API key at:")
                print(" https://console.groq.com/keys")
                return False
            # Test connection with a trivial prompt
            test_model = get_chat_model("groq")
            test_model.invoke("Say 'OK' in one word")
            print("OK: Groq API connection successful")
            return True
        elif provider == "gemini":
            api_key = get_google_api_key()
            if not api_key:
                print("WARN: GOOGLE_API_KEY not set")
                print("\n Get your FREE API key at:")
                print(" https://aistudio.google.com/app/apikey")
                return False
            test_model = get_chat_model("gemini")
            test_model.invoke("Say 'OK' in one word")
            print("OK: Google Gemini API connection successful")
            return True
        else:
            try:
                from langchain_ollama import ChatOllama
            except ImportError:
                from langchain_community.chat_models import ChatOllama
            test_model = ChatOllama(model="llama3.1:8b")
            test_model.invoke("Hello")
            print("OK: Ollama connection successful")
            return True
except Exception as e:
print(f"ERROR: Connection failed: {e}")
return False
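# Verify the configuration by running this file directly, e.g. from the
# project root (assuming `src` is a package):
#     python -m src.llm_config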
if __name__ == "__main__":
# Test configuration
llm_config.print_config()
check_api_connection()