File size: 4,661 Bytes
c2f9396 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | """
Configuration file for the LLM API.
"""
import os
from typing import Optional
# Model Configuration
class ModelConfig:
    """Static registry of the models this API knows about.

    Each mapping goes from a short, user-facing model name to the identifier
    its backend loads. For llama.cpp entries that is a relative ``.gguf`` file
    path. For Transformers entries it is an ``org/name`` style model ID.
    """

    # llama.cpp backend (GGUF format): short name -> relative .gguf path.
    LLAMA_MODELS = {
        "llama-2-7b-chat": "models/llama-2-7b-chat.Q4_K_M.gguf",
        "llama-2-13b-chat": "models/llama-2-13b-chat.Q4_K_M.gguf",
        "llama-3-8b": "models/llama-3-8b.Q4_K_M.gguf",
    }

    # Transformers backend: Microsoft Phi family.
    PHI_MODELS = {
        "phi-1": "microsoft/phi-1",
        "phi-1_5": "microsoft/phi-1_5",
        "phi-2": "microsoft/phi-2",
        "phi-3-mini": "microsoft/phi-3-mini-4k-instruct",
        "phi-3-small": "microsoft/phi-3-small-8k-instruct",
        "phi-3-medium": "microsoft/phi-3-medium-4k-instruct",
    }

    # Transformers backend: every other supported checkpoint.
    TRANSFORMERS_MODELS = {
        "dialo-gpt-medium": "microsoft/DialoGPT-medium",
        "gpt2": "gpt2",
        "gpt2-medium": "gpt2-medium",
    }

    @classmethod
    def get_model_path(cls, model_name: str) -> Optional[str]:
        """Return the path / model ID registered under *model_name*.

        The registries are searched in the order LLaMA, Phi, other
        Transformers, and the first match wins. Returns ``None`` when the name
        is not registered anywhere.
        """
        search_order = (cls.LLAMA_MODELS, cls.PHI_MODELS, cls.TRANSFORMERS_MODELS)
        for registry in search_order:
            if model_name in registry:
                return registry[model_name]
        return None

    @classmethod
    def get_model_type(cls, model_name: str) -> str:
        """Return the backend identifier for *model_name*.

        Returns ``"llama_cpp"`` for LLaMA entries and ``"transformers"`` for
        Phi or other Transformers entries. Any other name yields ``"unknown"``.
        """
        if model_name in cls.LLAMA_MODELS:
            return "llama_cpp"
        transformers_registries = (cls.PHI_MODELS, cls.TRANSFORMERS_MODELS)
        if any(model_name in registry for registry in transformers_registries):
            return "transformers"
        return "unknown"

    @classmethod
    def list_models(cls) -> dict:
        """Return the registered short names, grouped by registry."""
        grouped = (
            ("llama_models", cls.LLAMA_MODELS),
            ("phi_models", cls.PHI_MODELS),
            ("transformers_models", cls.TRANSFORMERS_MODELS),
        )
        return {label: list(registry) for label, registry in grouped}
# Environment Configuration
class Config:
    """Main configuration class.

    Every setting below is read from an environment variable once, when this
    class body is executed (at import time). A built-in default is used when
    the variable is unset.
    """

    # Model settings
    DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "phi-1_5")
    MODEL_PATH = os.getenv("MODEL_PATH", "models/llama-2-7b-chat.Q4_K_M.gguf")
    TRANSFORMERS_MODEL = os.getenv("TRANSFORMERS_MODEL", "microsoft/phi-1_5")

    # API settings
    # "0.0.0.0" binds every IPv4 interface, not just localhost.
    HOST = os.getenv("HOST", "0.0.0.0")
    # int() raises ValueError at import time if PORT is not numeric.
    PORT = int(os.getenv("PORT", "8000"))
    # Only the string "true" (any letter case) enables debug mode.
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"

    # Model parameters
    DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "2048"))
    DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
    DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.9"))

    # Logging
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

    @classmethod
    def setup_model_environment(cls, model_name: str) -> bool:
        """Select *model_name* as the active model for its backend.

        For a llama.cpp model this sets ``MODEL_PATH``. For a Transformers
        model it sets ``TRANSFORMERS_MODEL``. The variable is set in
        ``os.environ`` and the matching class attribute is updated too.

        Args:
            model_name: A short name registered in ``ModelConfig``.

        Returns:
            True if the model was recognised and configured, otherwise False.
        """
        model_path = ModelConfig.get_model_path(model_name)
        model_type = ModelConfig.get_model_type(model_name)

        if model_type == "llama_cpp" and model_path:
            os.environ["MODEL_PATH"] = model_path
            # The class attribute was snapshotted from os.environ at import
            # time; refresh it so Config.MODEL_PATH does not go stale.
            cls.MODEL_PATH = model_path
            print(f"✅ Set up LLaMA model: {model_name} -> {model_path}")
        elif model_type == "transformers" and model_path:
            os.environ["TRANSFORMERS_MODEL"] = model_path
            # Same reason as above: keep the class attribute in sync.
            cls.TRANSFORMERS_MODEL = model_path
            print(f"✅ Set up Transformers model: {model_name} -> {model_path}")
        else:
            print(f"❌ Unknown model: {model_name}")
            return False
        return True
# Convenience functions
def setup_phi_model(model_name: str = "phi-1_5"):
    """Activate a Phi model; defaults to ``phi-1_5``.

    Delegates to ``Config.setup_model_environment`` and returns its result
    unchanged. The name is not checked against the Phi registry, so any
    registered model name is accepted.
    """
    configure = Config.setup_model_environment
    return configure(model_name)
def setup_llama_model(model_name: str = "llama-2-7b-chat"):
    """Activate a LLaMA model; defaults to ``llama-2-7b-chat``.

    Delegates to ``Config.setup_model_environment`` and returns its result
    unchanged. The name is not checked against the LLaMA registry, so any
    registered model name is accepted.
    """
    configure = Config.setup_model_environment
    return configure(model_name)
def list_available_models():
    """Return every registered model name, grouped by registry.

    Shortcut for ``ModelConfig.list_models()``.
    """
    catalog = ModelConfig.list_models()
    return catalog
if __name__ == "__main__":
    # Demo: print each registered model with its backend, then the
    # model-related settings currently held by Config.
    print("Available Models:")
    for group_key, names in list_available_models().items():
        heading = group_key.replace("_", " ").title()
        print(f"\n{heading}:")
        for name in names:
            backend = ModelConfig.get_model_type(name)
            print(f" - {name} ({backend})")
    print(f"\nDefault model: {Config.DEFAULT_MODEL}")
    print(f"Model path: {Config.MODEL_PATH}")
    print(f"Transformers model: {Config.TRANSFORMERS_MODEL}")
|