# NOTE: removed non-Python page residue ("Spaces: Sleeping Sleeping") left over
# from a Hugging Face Spaces page scrape; it is not part of this module.
| """ | |
| MediGuard AI RAG-Helper | |
| LLM configuration and initialization | |
| Supports multiple providers: | |
| - Groq (FREE, fast, llama-3.3-70b) - RECOMMENDED | |
| - Google Gemini (FREE tier) | |
| - Ollama (local, for offline use) | |
| """ | |
import os
import threading
from typing import Literal, Optional
from dotenv import load_dotenv

# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Configure LangSmith tracing: keep any LANGCHAIN_PROJECT already set in the
# environment, otherwise tag traces with the default project name.
os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT", "MediGuard_AI_RAG_Helper")

# Default chat-model provider (can be overridden via the LLM_PROVIDER env var).
DEFAULT_LLM_PROVIDER = os.getenv("LLM_PROVIDER", "groq")
def get_chat_model(
    provider: Optional[Literal["groq", "gemini", "ollama"]] = None,
    model: Optional[str] = None,
    temperature: float = 0.0,
    json_mode: bool = False
):
    """
    Get a chat model from the specified provider.

    Provider SDKs are imported lazily inside each branch so that only the
    selected provider's package needs to be installed.

    Args:
        provider: "groq" (free, fast), "gemini" (free), or "ollama" (local).
            Defaults to the LLM_PROVIDER env var ("groq" when unset).
        model: Model name (provider-specific); a per-provider default is used
            when omitted.
        temperature: Sampling temperature.
        json_mode: Whether to enable JSON output mode (honored by the groq and
            ollama branches; see review note on the gemini branch).

    Returns:
        LangChain chat model instance.

    Raises:
        ValueError: If the provider is unknown, or its API key is missing.
    """
    provider = provider or DEFAULT_LLM_PROVIDER

    if provider == "groq":
        from langchain_groq import ChatGroq

        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError(
                "GROQ_API_KEY not found in environment.\n"
                "Get your FREE API key at: https://console.groq.com/keys"
            )
        # Default to llama-3.3-70b for best quality (free on Groq)
        model = model or "llama-3.3-70b-versatile"
        return ChatGroq(
            model=model,
            temperature=temperature,
            api_key=api_key,
            model_kwargs={"response_format": {"type": "json_object"}} if json_mode else {}
        )

    elif provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI

        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "GOOGLE_API_KEY not found in environment.\n"
                "Get your FREE API key at: https://aistudio.google.com/app/apikey"
            )
        # Default to Gemini 2.0 Flash (fast and free)
        model = model or "gemini-2.0-flash"
        # NOTE(review): json_mode is silently ignored on this branch — confirm
        # whether the installed langchain-google-genai version exposes a JSON
        # output option (e.g. response_mime_type) and wire it through if so.
        return ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            google_api_key=api_key,
            convert_system_message_to_human=True
        )

    elif provider == "ollama":
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            # Older installs ship ChatOllama in langchain_community instead.
            from langchain_community.chat_models import ChatOllama

        model = model or "llama3.1:8b"
        # BUGFIX: only pass `format` when JSON mode is requested. Newer
        # langchain-ollama declares format as Literal["", "json"], so the
        # previous explicit format=None could fail pydantic validation.
        extra = {"format": "json"} if json_mode else {}
        return ChatOllama(
            model=model,
            temperature=temperature,
            **extra
        )

    else:
        raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'gemini', or 'ollama'")
def get_embedding_model(provider: Optional[Literal["google", "huggingface", "ollama"]] = None):
    """
    Get embedding model for vector search.

    Args:
        provider: "google" (free, recommended), "huggingface" (local), or
            "ollama" (local). Defaults to the EMBEDDING_PROVIDER env var
            ("google" when unset).

    Returns:
        LangChain embedding model instance.

    Raises:
        ValueError: If the provider name is not recognized.
    """
    selected = provider or os.getenv("EMBEDDING_PROVIDER", "google")

    if selected == "google":
        from langchain_google_genai import GoogleGenerativeAIEmbeddings

        key = os.getenv("GOOGLE_API_KEY")
        if not key:
            # Best-effort degradation: local embeddings need no API key.
            print("WARN: GOOGLE_API_KEY not found. Falling back to HuggingFace embeddings.")
            return get_embedding_model("huggingface")
        try:
            return GoogleGenerativeAIEmbeddings(
                model="models/text-embedding-004",
                google_api_key=key
            )
        except Exception as e:
            print(f"WARN: Google embeddings failed: {e}")
            print("INFO: Falling back to HuggingFace embeddings...")
            return get_embedding_model("huggingface")

    if selected == "huggingface":
        try:
            from langchain_huggingface import HuggingFaceEmbeddings
        except ImportError:
            # Older installs ship the class in langchain_community.
            from langchain_community.embeddings import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    if selected == "ollama":
        try:
            from langchain_ollama import OllamaEmbeddings
        except ImportError:
            from langchain_community.embeddings import OllamaEmbeddings
        return OllamaEmbeddings(model="nomic-embed-text")

    raise ValueError(f"Unknown embedding provider: {selected}")
class LLMConfig:
    """Central configuration for all LLM models.

    Builds one chat-model client per pipeline role (planner, analyzer,
    explainer, synthesizers, director) plus an embedding model. With
    ``lazy=True`` (the default) nothing is constructed until the first
    accessor call, so importing this module never requires API keys.
    """

    def __init__(self, provider: Optional[str] = None, lazy: bool = True):
        """
        Initialize all model clients.

        Args:
            provider: LLM provider - "groq" (free), "gemini" (free), or
                "ollama" (local). Defaults to DEFAULT_LLM_PROVIDER.
            lazy: If True, defer model initialization until first use
                (avoids API key errors at import).
        """
        self.provider = provider or DEFAULT_LLM_PROVIDER
        self._lazy = lazy
        self._initialized = False
        self._lock = threading.Lock()
        # Lazy-initialized model instances (populated by _initialize_models)
        self._planner = None
        self._analyzer = None
        self._explainer = None
        self._synthesizer_7b = None
        self._synthesizer_8b = None
        self._director = None
        self._embedding_model = None
        if not lazy:
            self._initialize_models()

    def _initialize_models(self):
        """Initialize all model clients (called on first use if lazy)."""
        if self._initialized:
            return
        with self._lock:
            # Double-checked locking: another thread may have finished
            # initialization while we waited for the lock.
            if self._initialized:
                return
            print(f"Initializing LLM models with provider: {self.provider.upper()}")
            # Fast model for structured tasks (planning, analysis)
            self._planner = get_chat_model(
                provider=self.provider,
                temperature=0.0,
                json_mode=True
            )
            # Fast model for biomarker analysis and quick tasks
            self._analyzer = get_chat_model(
                provider=self.provider,
                temperature=0.0
            )
            # Medium model for RAG retrieval and explanation
            self._explainer = get_chat_model(
                provider=self.provider,
                temperature=0.2
            )
            # Configurable synthesizers
            self._synthesizer_7b = get_chat_model(
                provider=self.provider,
                temperature=0.2
            )
            self._synthesizer_8b = get_chat_model(
                provider=self.provider,
                temperature=0.2
            )
            # Director for Outer Loop (structured/JSON output)
            self._director = get_chat_model(
                provider=self.provider,
                temperature=0.0,
                json_mode=True
            )
            # Embedding model for RAG
            self._embedding_model = get_embedding_model()
            self._initialized = True

    def planner(self):
        """Return the planning model (JSON mode, temperature 0)."""
        self._initialize_models()
        return self._planner

    def analyzer(self):
        """Return the analysis model (temperature 0)."""
        self._initialize_models()
        return self._analyzer

    def explainer(self):
        """Return the explanation/RAG model (temperature 0.2)."""
        self._initialize_models()
        return self._explainer

    def synthesizer_7b(self):
        """Return the 7b-tier synthesizer model."""
        self._initialize_models()
        return self._synthesizer_7b

    def synthesizer_8b(self):
        """Return the 8b-tier synthesizer model."""
        self._initialize_models()
        return self._synthesizer_8b

    def director(self):
        """Return the Outer Loop director model (JSON mode)."""
        self._initialize_models()
        return self._director

    def embedding_model(self):
        """Return the embedding model for RAG vector search."""
        self._initialize_models()
        return self._embedding_model

    def get_synthesizer(self, model_name: Optional[str] = None):
        """Get synthesizer model (for backward compatibility)."""
        if model_name:
            return get_chat_model(provider=self.provider, model=model_name, temperature=0.2)
        # BUGFIX: call the accessor — previously this returned the bound
        # method object itself (`self.synthesizer_8b`) instead of the model.
        return self.synthesizer_8b()

    def print_config(self):
        """Print current LLM configuration."""
        print("=" * 60)
        print("MediGuard AI RAG-Helper - LLM Configuration")
        print("=" * 60)
        print(f"Provider: {self.provider.upper()}")
        if self.provider == "groq":
            print(f"Model: llama-3.3-70b-versatile (FREE)")
        elif self.provider == "gemini":
            print(f"Model: gemini-2.0-flash (FREE)")
        else:
            print(f"Model: llama3.1:8b (local)")
        print(f"Embeddings: Google Gemini (FREE)")
        print("=" * 60)
# Global LLM configuration instance. Constructed lazily (default lazy=True),
# so importing this module does not touch any provider API.
llm_config = LLMConfig()
def check_api_connection():
    """Verify API connection and keys are configured.

    Probes the provider selected by DEFAULT_LLM_PROVIDER with a tiny test
    prompt, printing guidance when the required API key is missing.

    Returns:
        bool: True when the provider responded, False otherwise.
    """
    provider = DEFAULT_LLM_PROVIDER
    try:
        if provider == "groq":
            if not os.getenv("GROQ_API_KEY"):
                print("WARN: GROQ_API_KEY not set")
                print("\n Get your FREE API key at:")
                print(" https://console.groq.com/keys")
                return False
            # Round-trip a trivial prompt to prove the key actually works.
            get_chat_model("groq").invoke("Say 'OK' in one word")
            print("OK: Groq API connection successful")
            return True

        if provider == "gemini":
            if not os.getenv("GOOGLE_API_KEY"):
                print("WARN: GOOGLE_API_KEY not set")
                print("\n Get your FREE API key at:")
                print(" https://aistudio.google.com/app/apikey")
                return False
            get_chat_model("gemini").invoke("Say 'OK' in one word")
            print("OK: Google Gemini API connection successful")
            return True

        # Any other provider value is treated as a local Ollama install.
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            from langchain_community.chat_models import ChatOllama
        ChatOllama(model="llama3.1:8b").invoke("Hello")
        print("OK: Ollama connection successful")
        return True
    except Exception as e:
        # Boundary handler: report the failure and signal it via the return
        # value instead of crashing the caller.
        print(f"ERROR: Connection failed: {e}")
        return False
| if __name__ == "__main__": | |
| # Test configuration | |
| llm_config.print_config() | |
| check_api_connection() | |