# medrax2/medrax/models/model_factory.py
# Commit b2fc7a6 (samwell): Replace Gemini with MedGemma-4B as main orchestrator
"""Factory for creating language model instances based on model name."""
import os
from typing import Dict, Any, Type
from langchain_core.language_models import BaseLanguageModel
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_xai import ChatXAI
from .medgemma import ChatMedGemma
class ModelFactory:
    """Factory for creating language model instances based on model name.

    Maintains a registry of language model providers keyed by model-name
    prefix and creates the appropriate LangChain model instance for a given
    model name (e.g. 'gpt-4o' -> ChatOpenAI, 'medgemma-4b' -> ChatMedGemma).
    """

    # Registry of model providers, keyed by model-name prefix. Each entry
    # holds the LangChain class, the environment variable for its API key,
    # and optional extras (base-URL configuration, local-model flag).
    _model_providers = {
        "gpt": {
            "class": ChatOpenAI,
            "env_key": "OPENAI_API_KEY",
            "base_url_key": "OPENAI_BASE_URL",
        },
        "chatgpt": {
            "class": ChatOpenAI,
            "env_key": "OPENAI_API_KEY",
            "base_url_key": "OPENAI_BASE_URL",
        },
        "gemini": {
            "class": ChatGoogleGenerativeAI,
            "env_key": "GOOGLE_API_KEY",
            "base_url_key": "GOOGLE_BASE_URL",
        },
        "openrouter": {
            "class": ChatOpenAI,  # OpenRouter uses OpenAI-compatible interface
            "env_key": "OPENROUTER_API_KEY",
            "base_url_key": "OPENROUTER_BASE_URL",
            "default_base_url": "https://openrouter.ai/api/v1",
        },
        "grok": {
            "class": ChatXAI,
            "env_key": "XAI_API_KEY",
        },
        "medgemma": {
            "class": ChatMedGemma,
            "env_key": None,  # Local model, no API key needed
            "is_local": True,
        },
        # Add more providers with default configurations here
    }

    # Registry keys that describe the provider itself (factory metadata)
    # rather than constructor kwargs. Everything else in an entry is
    # forwarded to the model class. "is_local" MUST be listed here so it
    # is not passed to the model constructor as an unexpected keyword.
    _META_KEYS = ("class", "env_key", "base_url_key", "default_base_url", "is_local")

    @classmethod
    def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs) -> None:
        """Register a new model provider.

        Args:
            prefix (str): The prefix used to identify this model provider (e.g., 'gpt', 'gemini')
            model_class (Type[BaseLanguageModel]): The LangChain model class to use
            env_key (str): The environment variable name for the API key
            **kwargs: Additional provider-specific configuration (e.g. base_url_key,
                default_base_url, is_local, or extra constructor kwargs)
        """
        cls._model_providers[prefix] = {"class": model_class, "env_key": env_key, **kwargs}

    @classmethod
    def create_model(
        cls, model_name: str, temperature: float = 0.7, top_p: float = 0.95, max_tokens: int = 5000, **kwargs
    ) -> BaseLanguageModel:
        """Create and return an instance of the appropriate language model.

        Args:
            model_name (str): Name of the model to create (e.g., 'gpt-4o', 'gemini-2.5-pro')
            temperature (float, optional): Temperature parameter. Defaults to 0.7.
            top_p (float, optional): Top-p sampling parameter. Defaults to 0.95.
            max_tokens (int, optional): Maximum tokens to generate. Defaults to 5000.
            **kwargs: Additional model-specific parameters

        Returns:
            BaseLanguageModel: An initialized language model instance

        Raises:
            ValueError: If no provider is found for the given model name
        """
        # Longest-prefix match: if overlapping prefixes are registered
        # (e.g. 'gpt' and a hypothetical 'gpt-oss'), the most specific wins.
        # Identical to first-match for the current, non-overlapping registry.
        candidates = [prefix for prefix in cls._model_providers if model_name.startswith(prefix)]
        if not candidates:
            raise ValueError(
                f"No provider found for model: {model_name}. "
                f"Registered providers are for: {list(cls._model_providers.keys())}"
            )
        provider_prefix = max(candidates, key=len)

        provider = cls._model_providers[provider_prefix]
        model_class = provider["class"]
        env_key = provider["env_key"]
        is_local = provider.get("is_local", False)

        provider_kwargs: Dict[str, Any] = {}

        # Resolve the API key from the environment (local models need none).
        if not is_local and env_key:
            if env_key in os.environ:
                provider_kwargs["api_key"] = os.environ[env_key]
            else:
                # Warn but don't fail: some model classes resolve
                # credentials through other channels.
                print(f"Warning: Environment variable {env_key} not found. Authentication may fail.")

        # Resolve the base URL: environment variable first, then the
        # registry's default (used by OpenRouter).
        if "base_url_key" in provider:
            if provider["base_url_key"] in os.environ:
                provider_kwargs["base_url"] = os.environ[provider["base_url_key"]]
            elif "default_base_url" in provider:
                provider_kwargs["base_url"] = provider["default_base_url"]

        # Forward any extra registry settings, excluding factory metadata.
        # (Previously "is_local" leaked through here and reached the
        # ChatMedGemma constructor as a stray keyword argument.)
        for key, value in provider.items():
            if key not in cls._META_KEYS:
                provider_kwargs[key] = value

        # Strip the provider prefix only where it is a pure routing tag:
        # 'openrouter-anthropic/claude-sonnet-4' -> 'anthropic/claude-sonnet-4'.
        # OpenAI-style names like 'gpt-4o' keep the prefix because it is
        # part of the real model id.
        actual_model_name = model_name
        if provider_prefix == "openrouter" and model_name.startswith(f"{provider_prefix}-"):
            actual_model_name = model_name[len(provider_prefix) + 1 :]

        # GPT-5 family: reasoning models take reasoning_effort; top_p is
        # deliberately omitted on this path.
        if model_name.startswith("gpt-5"):
            return model_class(
                model=actual_model_name,
                temperature=temperature,
                reasoning_effort="high",
                max_tokens=max_tokens,
                **provider_kwargs,
                **kwargs,
            )

        # MedGemma: local model whose constructor uses different parameter
        # names (model_name / max_new_tokens) plus device/quantization flags.
        if model_name.startswith("medgemma"):
            return model_class(
                model_name=actual_model_name,
                temperature=temperature,
                top_p=top_p,
                top_k=kwargs.get("top_k", 64),
                max_new_tokens=max_tokens,
                device=kwargs.get("device", "cuda"),
                load_in_4bit=kwargs.get("load_in_4bit", True),
                **provider_kwargs,
            )

        # Default path for hosted chat models.
        return model_class(
            model=actual_model_name,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            **provider_kwargs,
            **kwargs,
        )

    @classmethod
    def list_providers(cls) -> Dict[str, Dict[str, Any]]:
        """List all registered model providers.

        Returns:
            Dict[str, Dict[str, Any]]: Dictionary of registered providers and
                their configurations (the model class itself is omitted).
        """
        # Return a copy to prevent accidental modification of the registry.
        return {k: {kk: vv for kk, vv in v.items() if kk != "class"} for k, v in cls._model_providers.items()}