# medassist-protocol/utils/llm_factory.py
"""
LLM Factory - Centralized LLM provider management
Supports: Ollama, HuggingFace, Together AI, Groq, and more
"""
import os
from typing import Dict, Optional

from langchain_community.chat_models import ChatOllama
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_huggingface import HuggingFaceEndpoint


class LLMFactory:
    """Factory for creating LLM instances across different providers."""

    # Model configurations: short aliases mapped to provider-specific
    # model identifiers (and context window sizes, where tracked)
    MODELS = {
        "ollama": {
            "llama3.1": {"model": "llama3.1:8b", "context": 128000},
            "mistral": {"model": "mistral:7b", "context": 32000},
            "mixtral": {"model": "mixtral:8x7b", "context": 32000},
            "meditron": {"model": "meditron:7b", "context": 4096},  # Medical-specific
            "biomistral": {"model": "biomistral:7b", "context": 4096},  # Medical-specific
        },
        "huggingface": {
            "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",
            "zephyr-7b": "HuggingFaceH4/zephyr-7b-beta",
            "meditron-7b": "epfl-llm/meditron-7b",  # Medical-specific
            "biomistral-7b": "BioMistral/BioMistral-7B",  # Medical-specific
        },
        "together": {
            "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",
            "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        },
        "groq": {
            "llama-3.1-8b": "llama-3.1-8b-instant",
            "mixtral-8x7b": "mixtral-8x7b-32768",
        },
    }
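
    # Alias resolution (illustrative note, not in the original source): for
    # provider "ollama", the alias "llama3.1" resolves via
    # MODELS["ollama"]["llama3.1"]["model"] to the Ollama tag "llama3.1:8b".
    # Aliases not found in MODELS fall through unchanged, so a raw model
    # identifier can always be passed directly.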

    @classmethod
    def create_llm(
        cls,
        provider: str,
        model_name: str,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        """
        Create an LLM instance for the given provider.

        Args:
            provider: One of 'ollama', 'huggingface', 'together', 'groq'
            model_name: Model alias (see MODELS) or raw model identifier
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            api_key: API key (required for hosted providers)
            base_url: Custom endpoint URL (for remote Ollama servers)

        Raises:
            ValueError: If the provider is not recognized.
        """
        if provider == "ollama":
            return cls._create_ollama(model_name, temperature, max_tokens, base_url)
        elif provider == "huggingface":
            return cls._create_huggingface(model_name, temperature, max_tokens, api_key)
        elif provider == "together":
            return cls._create_together(model_name, temperature, max_tokens, api_key)
        elif provider == "groq":
            return cls._create_groq(model_name, temperature, max_tokens, api_key)
        else:
            raise ValueError(f"Unknown provider: {provider}")

    @classmethod
    def _create_ollama(cls, model_name: str, temperature: float, max_tokens: int, base_url: Optional[str]):
        """Create an Ollama chat model instance."""
        # Resolve the alias to a full Ollama tag; fall back to the raw name
        model_config = cls.MODELS["ollama"].get(model_name, {})
        actual_model = model_config.get("model", model_name)
        # Use the custom base_url if provided (for a remote Ollama server);
        # otherwise default to the local daemon
        ollama_base_url = base_url or "http://localhost:11434"
        return ChatOllama(
            model=actual_model,
            temperature=temperature,
            num_predict=max_tokens,  # Ollama's name for the generation cap
            base_url=ollama_base_url,
            # Stream tokens to stdout as they are generated
            callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
        )

    @classmethod
    def _create_huggingface(cls, model_name: str, temperature: float, max_tokens: int, api_key: Optional[str]):
        """Create a HuggingFace Inference Endpoint instance."""
        # Resolve the alias to a full Hub repo id; fall back to the raw name
        model_path = cls.MODELS["huggingface"].get(model_name, model_name)
        # Prefer the explicit key, then the environment variable
        hf_token = api_key or os.environ.get("HUGGINGFACE_API_KEY")
        return HuggingFaceEndpoint(
            repo_id=model_path,
            temperature=temperature,
            max_new_tokens=max_tokens,
            huggingfacehub_api_token=hf_token,
            task="text-generation",
        )

    @classmethod
    def _create_groq(cls, model_name: str, temperature: float, max_tokens: int, api_key: Optional[str]):
        """Create a Groq chat model instance."""
        # Imported lazily so the package is only required when Groq is used
        from langchain_groq import ChatGroq

        model_path = cls.MODELS["groq"].get(model_name, model_name)
        groq_api_key = api_key or os.environ.get("GROQ_API_KEY")
        return ChatGroq(
            model=model_path,
            temperature=temperature,
            max_tokens=max_tokens,
            groq_api_key=groq_api_key,
        )
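
    @classmethod
    def _create_together(cls, model_name: str, temperature: float, max_tokens: int, api_key: Optional[str]):
        """Create a Together AI chat model instance.

        Note: create_llm dispatches here for provider "together", but the
        original module did not define this method. This is a minimal sketch
        assuming the langchain_together package and its ChatTogether class;
        the env var name TOGETHER_API_KEY is likewise an assumption.
        """
        # Imported lazily, mirroring _create_groq (assumed optional dependency)
        from langchain_together import ChatTogether

        model_path = cls.MODELS["together"].get(model_name, model_name)
        together_api_key = api_key or os.environ.get("TOGETHER_API_KEY")
        return ChatTogether(
            model=model_path,
            temperature=temperature,
            max_tokens=max_tokens,
            api_key=together_api_key,
        )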

    @classmethod
    def get_available_models(cls, provider: str) -> Dict:
        """Return the configured model aliases for a provider (empty if unknown)."""
        return cls.MODELS.get(provider, {})

    @classmethod
    def is_medical_model(cls, model_name: str) -> bool:
        """Check whether a model name looks medical-specific (keyword match)."""
        medical_keywords = ["meditron", "biomistral", "medical", "clinical", "bio"]
        return any(keyword in model_name.lower() for keyword in medical_keywords)
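

# Illustrative usage (not part of the original module). The Ollama example
# assumes a local Ollama server on the default port with the llama3.1:8b
# model already pulled.
if __name__ == "__main__":
    print(LLMFactory.get_available_models("ollama"))
    print(LLMFactory.is_medical_model("biomistral:7b"))  # True

    llm = LLMFactory.create_llm("ollama", "llama3.1", temperature=0.2)
    print(llm.invoke("Say hello in one sentence."))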