"""
Configuration for the LLM API: model registries and environment-based settings.
"""
| |
|
| | import os |
| | from typing import Optional |
| |
|
| |
|
| | |
class ModelConfig:
    """Registries of known model names and the target each one loads from.

    Each registry maps a short model name to a load target: a relative
    ``.gguf`` file path for the llama.cpp-backed models, or a model
    identifier handed to the Transformers backend for the Phi and plain
    Transformers models (these look like Hugging Face Hub repo IDs).
    """

    # llama.cpp models: short name -> relative path of a .gguf file.
    LLAMA_MODELS = {
        "llama-2-7b-chat": "models/llama-2-7b-chat.Q4_K_M.gguf",
        "llama-2-13b-chat": "models/llama-2-13b-chat.Q4_K_M.gguf",
        "llama-3-8b": "models/llama-3-8b.Q4_K_M.gguf",
    }

    # Phi models; get_model_type() routes these to the Transformers backend.
    PHI_MODELS = {
        "phi-1": "microsoft/phi-1",
        "phi-1_5": "microsoft/phi-1_5",
        "phi-2": "microsoft/phi-2",
        "phi-3-mini": "microsoft/phi-3-mini-4k-instruct",
        "phi-3-small": "microsoft/phi-3-small-8k-instruct",
        "phi-3-medium": "microsoft/phi-3-medium-4k-instruct",
    }

    # Other models routed to the Transformers backend.
    TRANSFORMERS_MODELS = {
        "dialo-gpt-medium": "microsoft/DialoGPT-medium",
        "gpt2": "gpt2",
        "gpt2-medium": "gpt2-medium",
    }

    @classmethod
    def get_model_path(cls, model_name: str) -> Optional[str]:
        """Return the load target registered for ``model_name``.

        The registries are searched in order: LLaMA, Phi, then plain
        Transformers; the first hit wins.

        Returns:
            The registered file path or model identifier, or ``None`` when
            the name appears in no registry.
        """
        registries = (cls.LLAMA_MODELS, cls.PHI_MODELS, cls.TRANSFORMERS_MODELS)
        for registry in registries:
            if model_name in registry:
                return registry[model_name]
        return None

    @classmethod
    def get_model_type(cls, model_name: str) -> str:
        """Return the backend that serves ``model_name``.

        Returns:
            ``"llama_cpp"`` for LLaMA registry entries, ``"transformers"``
            for Phi or Transformers registry entries, ``"unknown"`` otherwise.
        """
        if model_name in cls.LLAMA_MODELS:
            return "llama_cpp"
        uses_transformers = (
            model_name in cls.PHI_MODELS or model_name in cls.TRANSFORMERS_MODELS
        )
        return "transformers" if uses_transformers else "unknown"

    @classmethod
    def list_models(cls) -> dict:
        """Return all registered model names grouped by registry.

        Keys are ``"llama_models"``, ``"phi_models"`` and
        ``"transformers_models"`` (in that order); each value lists the
        registry's names in insertion order.
        """
        labelled_registries = (
            ("llama_models", cls.LLAMA_MODELS),
            ("phi_models", cls.PHI_MODELS),
            ("transformers_models", cls.TRANSFORMERS_MODELS),
        )
        return {label: [*registry] for label, registry in labelled_registries}
| |
|
| |
|
| | |
class Config:
    """Main application settings, read from environment variables.

    Every attribute below is evaluated once, when the class body runs at
    import time; later changes to ``os.environ`` are not picked up
    automatically (``setup_model_environment`` updates both explicitly).
    """

    # Model selection.
    DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "phi-1_5")
    # Load target used for llama.cpp models.
    MODEL_PATH = os.getenv("MODEL_PATH", "models/llama-2-7b-chat.Q4_K_M.gguf")
    # Load target used for Transformers models.
    TRANSFORMERS_MODEL = os.getenv("TRANSFORMERS_MODEL", "microsoft/phi-1_5")

    # Server settings.
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", "8000"))
    # Only the string "true" (any case) enables debug mode.
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"

    # Default generation parameters.
    DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "2048"))
    DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
    DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.9"))

    # Logging.
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

    @classmethod
    def setup_model_environment(cls, model_name: str) -> bool:
        """Point the model-loading settings at ``model_name``.

        For a llama.cpp model, the registered path is written to
        ``os.environ["MODEL_PATH"]`` and ``cls.MODEL_PATH``. For a
        Transformers model, it is written to
        ``os.environ["TRANSFORMERS_MODEL"]`` and ``cls.TRANSFORMERS_MODEL``.
        A status line is printed either way.

        Args:
            model_name: A name registered in ``ModelConfig``.

        Returns:
            ``True`` if the model was recognised and the settings updated,
            ``False`` if the name is unknown (nothing is changed then).
        """
        model_path = ModelConfig.get_model_path(model_name)
        model_type = ModelConfig.get_model_type(model_name)

        if model_type == "llama_cpp" and model_path:
            os.environ["MODEL_PATH"] = model_path
            # The class attribute was captured at import time; without this
            # update, readers of Config.MODEL_PATH would keep the old value.
            cls.MODEL_PATH = model_path
            print(f"✅ Set up LLaMA model: {model_name} -> {model_path}")
        elif model_type == "transformers" and model_path:
            os.environ["TRANSFORMERS_MODEL"] = model_path
            # Same reasoning as above: keep the class view in sync.
            cls.TRANSFORMERS_MODEL = model_path
            print(f"✅ Set up Transformers model: {model_name} -> {model_path}")
        else:
            print(f"❌ Unknown model: {model_name}")
            return False

        return True
| |
|
| |
|
| | |
def setup_phi_model(model_name: str = "phi-1_5"):
    """Configure the environment for a Phi model (``phi-1_5`` by default).

    This is a convenience wrapper around ``Config.setup_model_environment``
    and returns its result unchanged. The name is not checked against the
    Phi registry; any registered model name is accepted.
    """
    succeeded = Config.setup_model_environment(model_name)
    return succeeded
| |
|
| |
|
def setup_llama_model(model_name: str = "llama-2-7b-chat"):
    """Configure the environment for a LLaMA model (``llama-2-7b-chat`` by default).

    This is a convenience wrapper around ``Config.setup_model_environment``
    and returns its result unchanged. The name is not checked against the
    LLaMA registry; any registered model name is accepted.
    """
    succeeded = Config.setup_model_environment(model_name)
    return succeeded
| |
|
| |
|
def list_available_models():
    """Return every registered model name, grouped by registry.

    Delegates to ``ModelConfig.list_models``.
    """
    catalog = ModelConfig.list_models()
    return catalog
| |
|
| |
|
if __name__ == "__main__":
    # Print the model catalogue, then the model settings currently in effect.
    print("Available Models:")
    for category, names in list_available_models().items():
        heading = category.replace("_", " ").title()
        print(f"\n{heading}:")
        for name in names:
            print(f" - {name} ({ModelConfig.get_model_type(name)})")

    print(f"\nDefault model: {Config.DEFAULT_MODEL}")
    print(f"Model path: {Config.MODEL_PATH}")
    print(f"Transformers model: {Config.TRANSFORMERS_MODEL}")
| |
|