# chat-bot / config.py
# surahj's picture
# Initial commit: LLM Chat Interface for HF Spaces
# c2f9396
"""
Configuration file for the LLM API.
"""
import os
from typing import Optional
# Model Configuration
class ModelConfig:
    """Registry of supported models, grouped by the backend that serves them."""

    # LLaMA models shipped as local GGUF files (llama.cpp backend).
    LLAMA_MODELS = {
        "llama-2-7b-chat": "models/llama-2-7b-chat.Q4_K_M.gguf",
        "llama-2-13b-chat": "models/llama-2-13b-chat.Q4_K_M.gguf",
        "llama-3-8b": "models/llama-3-8b.Q4_K_M.gguf",
    }

    # Microsoft Phi models, loaded by Hugging Face repo id via Transformers.
    PHI_MODELS = {
        "phi-1": "microsoft/phi-1",
        "phi-1_5": "microsoft/phi-1_5",
        "phi-2": "microsoft/phi-2",
        "phi-3-mini": "microsoft/phi-3-mini-4k-instruct",
        "phi-3-small": "microsoft/phi-3-small-8k-instruct",
        "phi-3-medium": "microsoft/phi-3-medium-4k-instruct",
    }

    # Miscellaneous Transformers models (also HF repo ids).
    TRANSFORMERS_MODELS = {
        "dialo-gpt-medium": "microsoft/DialoGPT-medium",
        "gpt2": "gpt2",
        "gpt2-medium": "gpt2-medium",
    }

    @classmethod
    def get_model_path(cls, model_name: str) -> Optional[str]:
        """Return the file path / repo id for *model_name*, or None if unknown.

        LLaMA entries win over Phi entries, which win over the generic
        Transformers entries (same precedence as checking each dict in turn).
        """
        for registry in (cls.LLAMA_MODELS, cls.PHI_MODELS, cls.TRANSFORMERS_MODELS):
            if model_name in registry:
                return registry[model_name]
        return None

    @classmethod
    def get_model_type(cls, model_name: str) -> str:
        """Classify *model_name* as "llama_cpp", "transformers", or "unknown"."""
        if model_name in cls.LLAMA_MODELS:
            return "llama_cpp"
        if model_name in cls.PHI_MODELS or model_name in cls.TRANSFORMERS_MODELS:
            return "transformers"
        return "unknown"

    @classmethod
    def list_models(cls) -> dict:
        """Return every registered model name, keyed by backend category."""
        return {
            "llama_models": [*cls.LLAMA_MODELS],
            "phi_models": [*cls.PHI_MODELS],
            "transformers_models": [*cls.TRANSFORMERS_MODELS],
        }
# Environment Configuration
class Config:
    """Runtime configuration, sourced from environment variables with defaults."""

    # Model selection (name, GGUF path for llama.cpp, HF repo for Transformers).
    DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "phi-1_5")
    MODEL_PATH = os.getenv("MODEL_PATH", "models/llama-2-7b-chat.Q4_K_M.gguf")
    TRANSFORMERS_MODEL = os.getenv("TRANSFORMERS_MODEL", "microsoft/phi-1_5")

    # HTTP server settings.
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", "8000"))
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"

    # Generation defaults.
    DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "2048"))
    DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
    DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.9"))

    # Logging verbosity.
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

    @classmethod
    def setup_model_environment(cls, model_name: str):
        """Export the env var the chosen backend reads for *model_name*.

        Returns True on success, False when the model is not registered.
        """
        resolved_path = ModelConfig.get_model_path(model_name)
        backend = ModelConfig.get_model_type(model_name)

        # Guard clause: nothing to export for an unregistered model.
        if resolved_path is None:
            print(f"❌ Unknown model: {model_name}")
            return False

        if backend == "llama_cpp":
            os.environ["MODEL_PATH"] = resolved_path
            print(f"✅ Set up LLaMA model: {model_name} -> {resolved_path}")
        elif backend == "transformers":
            os.environ["TRANSFORMERS_MODEL"] = resolved_path
            print(f"✅ Set up Transformers model: {model_name} -> {resolved_path}")
        else:
            print(f"❌ Unknown model: {model_name}")
            return False
        return True
# Convenience functions
def setup_phi_model(model_name: str = "phi-1_5"):
    """Convenience wrapper: configure the environment for a Phi model."""
    success = Config.setup_model_environment(model_name)
    return success
def setup_llama_model(model_name: str = "llama-2-7b-chat"):
    """Convenience wrapper: configure the environment for a LLaMA model."""
    success = Config.setup_model_environment(model_name)
    return success
def list_available_models():
    """Return every registered model name, grouped by backend category."""
    catalogue = ModelConfig.list_models()
    return catalogue
if __name__ == "__main__":
    # Demo: print the model catalogue (with backend type) and current defaults.
    print("Available Models:")
    for category, names in list_available_models().items():
        print(f"\n{category.replace('_', ' ').title()}:")
        for name in names:
            backend = ModelConfig.get_model_type(name)
            print(f" - {name} ({backend})")
    print(f"\nDefault model: {Config.DEFAULT_MODEL}")
    print(f"Model path: {Config.MODEL_PATH}")
    print(f"Transformers model: {Config.TRANSFORMERS_MODEL}")