"""
MediGuard AI RAG-Helper
LLM configuration and initialization
Supports multiple providers:
- Groq (FREE, fast, llama-3.3-70b) - RECOMMENDED
- Google Gemini (FREE tier)
- Ollama (local, for offline use)
"""
import os
import threading
from typing import Literal, Optional
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Configure the LangSmith tracing project (keeps any value already set)
os.environ.setdefault("LANGCHAIN_PROJECT", "MediGuard_AI_RAG_Helper")
# Default provider (can be overridden via env)
DEFAULT_LLM_PROVIDER = os.getenv("LLM_PROVIDER", "groq")
def get_chat_model(
provider: Optional[Literal["groq", "gemini", "ollama"]] = None,
model: Optional[str] = None,
temperature: float = 0.0,
json_mode: bool = False
):
"""
Get a chat model from the specified provider.
Args:
provider: "groq" (free, fast), "gemini" (free), or "ollama" (local)
model: Model name (provider-specific)
temperature: Sampling temperature
json_mode: Whether to enable JSON output mode
Returns:
LangChain chat model instance
"""
provider = provider or DEFAULT_LLM_PROVIDER
if provider == "groq":
from langchain_groq import ChatGroq
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
raise ValueError(
"GROQ_API_KEY not found in environment.\n"
"Get your FREE API key at: https://console.groq.com/keys"
)
# Default to llama-3.3-70b for best quality (free on Groq)
model = model or "llama-3.3-70b-versatile"
return ChatGroq(
model=model,
temperature=temperature,
api_key=api_key,
model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
)
elif provider == "gemini":
from langchain_google_genai import ChatGoogleGenerativeAI
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
raise ValueError(
"GOOGLE_API_KEY not found in environment.\n"
"Get your FREE API key at: https://aistudio.google.com/app/apikey"
)
# Default to Gemini 2.0 Flash (fast and free)
model = model or "gemini-2.0-flash"
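        # Note: unlike the Groq and Ollama branches, json_mode is not applied
        # here; Gemini callers needing JSON should request it in the prompt.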
return ChatGoogleGenerativeAI(
model=model,
temperature=temperature,
google_api_key=api_key,
convert_system_message_to_human=True
)
elif provider == "ollama":
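        # Prefer the dedicated langchain-ollama package; fall back to the
        # legacy langchain_community import for older installs.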
try:
from langchain_ollama import ChatOllama
except ImportError:
from langchain_community.chat_models import ChatOllama
model = model or "llama3.1:8b"
return ChatOllama(
model=model,
temperature=temperature,
format='json' if json_mode else None
)
else:
raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
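# Example usage (a sketch; assumes GROQ_API_KEY is set in the environment):
#
#     llm = get_chat_model("groq", temperature=0.0, json_mode=True)
#     reply = llm.invoke("Return a JSON object with a single key 'status'.")
#     print(reply.content)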
def get_embedding_model(provider: Optional[Literal["google", "huggingface", "ollama"]] = None):
"""
Get embedding model for vector search.
Args:
provider: "google" (free, recommended), "huggingface" (local), or "ollama" (local)
Returns:
LangChain embedding model instance
"""
provider = provider or os.getenv("EMBEDDING_PROVIDER", "google")
if provider == "google":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
return get_embedding_model("huggingface")
try:
return GoogleGenerativeAIEmbeddings(
model="models/text-embedding-004",
google_api_key=api_key
)
except Exception as e:
print(f"WARN: Google embeddings failed: {e}")
print("INFO: Falling back to HuggingFace embeddings...")
return get_embedding_model("huggingface")
elif provider == "huggingface":
try:
from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
from langchain_community.embeddings import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
elif provider == "ollama":
try:
from langchain_ollama import OllamaEmbeddings
except ImportError:
from langchain_community.embeddings import OllamaEmbeddings
return OllamaEmbeddings(model="nomic-embed-text")
else:
raise ValueError(f"Unknown embedding provider: {provider}")
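# Example usage (a sketch; as implemented above, this silently falls back to
# local HuggingFace embeddings when GOOGLE_API_KEY is missing):
#
#     emb = get_embedding_model()
#     vector = emb.embed_query("fasting glucose reference range")
#     print(len(vector))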
class LLMConfig:
"""Central configuration for all LLM models"""
def __init__(self, provider: Optional[str] = None, lazy: bool = True):
"""
Initialize all model clients.
Args:
provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local)
lazy: If True, defer model initialization until first use (avoids API key errors at import)
"""
self.provider = provider or DEFAULT_LLM_PROVIDER
self._lazy = lazy
self._initialized = False
self._lock = threading.Lock()
# Lazy-initialized model instances
self._planner = None
self._analyzer = None
self._explainer = None
self._synthesizer_7b = None
self._synthesizer_8b = None
self._director = None
self._embedding_model = None
if not lazy:
self._initialize_models()
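    # Example (a sketch): with lazy=True (the default), constructing LLMConfig
    # makes no network calls and needs no API key; the first property access
    # (e.g. config.planner) triggers _initialize_models():
    #
    #     config = LLMConfig(provider="groq")
    #     plan = config.planner.invoke("Return the plan as a JSON object.")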
def _initialize_models(self):
"""Initialize all model clients (called on first use if lazy)"""
if self._initialized:
return
with self._lock:
# Double-checked locking
if self._initialized:
return
print(f"Initializing LLM models with provider: {self.provider.upper()}")
# Fast model for structured tasks (planning, analysis)
self._planner = get_chat_model(
provider=self.provider,
temperature=0.0,
json_mode=True
)
# Fast model for biomarker analysis and quick tasks
self._analyzer = get_chat_model(
provider=self.provider,
temperature=0.0
)
# Medium model for RAG retrieval and explanation
self._explainer = get_chat_model(
provider=self.provider,
temperature=0.2
)
# Configurable synthesizers
self._synthesizer_7b = get_chat_model(
provider=self.provider,
temperature=0.2
)
self._synthesizer_8b = get_chat_model(
provider=self.provider,
temperature=0.2
)
# Director for Outer Loop
self._director = get_chat_model(
provider=self.provider,
temperature=0.0,
json_mode=True
)
# Embedding model for RAG
self._embedding_model = get_embedding_model()
self._initialized = True
@property
def planner(self):
self._initialize_models()
return self._planner
@property
def analyzer(self):
self._initialize_models()
return self._analyzer
@property
def explainer(self):
self._initialize_models()
return self._explainer
@property
def synthesizer_7b(self):
self._initialize_models()
return self._synthesizer_7b
@property
def synthesizer_8b(self):
self._initialize_models()
return self._synthesizer_8b
@property
def director(self):
self._initialize_models()
return self._director
@property
def embedding_model(self):
self._initialize_models()
return self._embedding_model
def get_synthesizer(self, model_name: Optional[str] = None):
"""Get synthesizer model (for backward compatibility)"""
if model_name:
return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
return self.synthesizer_8b
def print_config(self):
"""Print current LLM configuration"""
print("=" * 60)
print("MediGuard AI RAG-Helper - LLM Configuration")
print("=" * 60)
print(f"Provider: {self.provider.upper()}")
        if self.provider == "groq":
            print("Model: llama-3.3-70b-versatile (FREE)")
        elif self.provider == "gemini":
            print("Model: gemini-2.0-flash (FREE)")
        else:
            print("Model: llama3.1:8b (local)")
        print(f"Embeddings: {os.getenv('EMBEDDING_PROVIDER', 'google')}")
print("=" * 60)
# Global LLM configuration instance
llm_config = LLMConfig()
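# Example (a sketch; assumes this module is importable as src.llm_config):
#
#     from src.llm_config import llm_config
#     llm_config.print_config()
#     answer = llm_config.explainer.invoke("Explain what HbA1c measures.")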
def check_api_connection():
"""Verify API connection and keys are configured"""
provider = DEFAULT_LLM_PROVIDER
try:
if provider == "groq":
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
print("WARN: GROQ_API_KEY not set")
print("\n Get your FREE API key at:")
print(" https://console.groq.com/keys")
return False
# Test connection
test_model = get_chat_model("groq")
            test_model.invoke("Say 'OK' in one word")
print("OK: Groq API connection successful")
return True
elif provider == "gemini":
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
print("WARN: GOOGLE_API_KEY not set")
print("\n Get your FREE API key at:")
print(" https://aistudio.google.com/app/apikey")
return False
test_model = get_chat_model("gemini")
            test_model.invoke("Say 'OK' in one word")
print("OK: Google Gemini API connection successful")
return True
else:
try:
from langchain_ollama import ChatOllama
except ImportError:
from langchain_community.chat_models import ChatOllama
test_model = ChatOllama(model="llama3.1:8b")
            test_model.invoke("Hello")
print("OK: Ollama connection successful")
return True
except Exception as e:
print(f"ERROR: Connection failed: {e}")
return False
if __name__ == "__main__":
# Test configuration
llm_config.print_config()
check_api_connection()