| """ | |
| LLM Factory - Centralized LLM provider management | |
| Supports: Ollama, HuggingFace, Together AI, Groq, and more | |
| """ | |
| from typing import Optional, Dict | |
| import os | |
| from langchain_community.llms import Ollama | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_community.chat_models import ChatOllama | |
| from langchain_core.callbacks.manager import CallbackManager | |
| from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
class LLMFactory:
    """Factory for creating LLM instances across providers."""

    # Model configurations per provider. Ollama entries carry the model tag
    # and context window; the other providers map short names to model IDs.
    MODELS = {
        "ollama": {
            "llama3.1": {"model": "llama3.1:8b", "context": 128000},
            "mistral": {"model": "mistral:7b", "context": 32000},
            "mixtral": {"model": "mixtral:8x7b", "context": 32000},
            "meditron": {"model": "meditron:7b", "context": 4096},      # Medical-specific
            "biomistral": {"model": "biomistral:7b", "context": 4096},  # Medical-specific
        },
        "huggingface": {
            "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",
            "zephyr-7b": "HuggingFaceH4/zephyr-7b-beta",
            "meditron-7b": "epfl-llm/meditron-7b",        # Medical-specific
            "biomistral-7b": "BioMistral/BioMistral-7B",  # Medical-specific
        },
        "together": {
            "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",
            "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        },
        "groq": {
            "llama-3.1-8b": "llama-3.1-8b-instant",
            "mixtral-8x7b": "mixtral-8x7b-32768",
        },
    }
    @classmethod
    def create_llm(
        cls,
        provider: str,
        model_name: str,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        """
        Create an LLM instance for the given provider.

        Args:
            provider: One of 'ollama', 'huggingface', 'together', 'groq'.
            model_name: Model identifier (a key in MODELS or a raw model ID).
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            api_key: API key, if the provider requires one.
            base_url: Custom endpoint URL (used by Ollama).
        """
        if provider == "ollama":
            return cls._create_ollama(model_name, temperature, max_tokens, base_url)
        elif provider == "huggingface":
            return cls._create_huggingface(model_name, temperature, max_tokens, api_key)
        elif provider == "together":
            return cls._create_together(model_name, temperature, max_tokens, api_key)
        elif provider == "groq":
            return cls._create_groq(model_name, temperature, max_tokens, api_key)
        else:
            raise ValueError(f"Unknown provider: {provider}")
    @classmethod
    def _create_ollama(cls, model_name: str, temperature: float, max_tokens: int, base_url: Optional[str]):
        """Create an Ollama chat model, streaming tokens to stdout."""
        model_config = cls.MODELS["ollama"].get(model_name, {})
        actual_model = model_config.get("model", model_name)

        # Use the custom base_url if provided (for a remote Ollama server);
        # otherwise fall back to the default local endpoint.
        ollama_base_url = base_url or "http://localhost:11434"

        return ChatOllama(
            model=actual_model,
            temperature=temperature,
            num_predict=max_tokens,
            base_url=ollama_base_url,
            callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
        )
    @classmethod
    def _create_huggingface(cls, model_name: str, temperature: float, max_tokens: int, api_key: Optional[str]):
        """Create a HuggingFace Inference Endpoint LLM."""
        # Resolve a short name to the full repo ID, or pass it through as-is.
        model_path = cls.MODELS["huggingface"].get(model_name, model_name)

        # Use the HF token from the parameter or from the environment.
        hf_token = api_key or os.environ.get("HUGGINGFACE_API_KEY")

        return HuggingFaceEndpoint(
            repo_id=model_path,
            temperature=temperature,
            max_new_tokens=max_tokens,
            huggingfacehub_api_token=hf_token,
            task="text-generation",
        )
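    @classmethod
    def _create_together(cls, model_name: str, temperature: float, max_tokens: int, api_key: Optional[str]):
        """Create a Together AI LLM instance.

        Minimal sketch: create_llm dispatches here, but the module did not
        define this method. It assumes the langchain_together package and its
        ChatTogether class, mirroring the _create_groq pattern below.
        """
        from langchain_together import ChatTogether

        model_path = cls.MODELS["together"].get(model_name, model_name)
        together_api_key = api_key or os.environ.get("TOGETHER_API_KEY")

        return ChatTogether(
            model=model_path,
            temperature=temperature,
            max_tokens=max_tokens,
            together_api_key=together_api_key,
        )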
    @classmethod
    def _create_groq(cls, model_name: str, temperature: float, max_tokens: int, api_key: Optional[str]):
        """Create a Groq LLM instance."""
        from langchain_groq import ChatGroq

        model_path = cls.MODELS["groq"].get(model_name, model_name)
        groq_api_key = api_key or os.environ.get("GROQ_API_KEY")

        return ChatGroq(
            model=model_path,
            temperature=temperature,
            max_tokens=max_tokens,
            groq_api_key=groq_api_key,
        )
    @classmethod
    def get_available_models(cls, provider: str) -> Dict:
        """Return the configured models for a provider (empty dict if unknown)."""
        return cls.MODELS.get(provider, {})
    @classmethod
    def is_medical_model(cls, model_name: str) -> bool:
        """Check whether a model name looks medical-specific."""
        medical_keywords = ['meditron', 'biomistral', 'medical', 'clinical', 'bio']
        return any(keyword in model_name.lower() for keyword in medical_keywords)
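
# Example usage sketch: inspecting the registry and building an Ollama-backed
# chat model. Assumes an Ollama server is running on the default local port
# for the returned model to actually serve requests.
if __name__ == "__main__":
    print(LLMFactory.get_available_models("ollama"))
    print(LLMFactory.is_medical_model("biomistral:7b"))  # True

    llm = LLMFactory.create_llm(provider="ollama", model_name="llama3.1")
    print(type(llm).__name__)  # ChatOllama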