File size: 4,661 Bytes
c2f9396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Configuration file for the LLM API.
"""

import os
from typing import Optional


# Model Configuration
class ModelConfig:
    """Registry of supported models plus helpers to resolve a model name
    to its on-disk path / HuggingFace id and to its loading backend."""

    # LLaMA Models (GGUF format) — local file paths, loaded via llama.cpp
    LLAMA_MODELS = {
        "llama-2-7b-chat": "models/llama-2-7b-chat.Q4_K_M.gguf",
        "llama-2-13b-chat": "models/llama-2-13b-chat.Q4_K_M.gguf",
        "llama-3-8b": "models/llama-3-8b.Q4_K_M.gguf",
    }

    # Microsoft Phi Models — HuggingFace hub ids, loaded via transformers
    PHI_MODELS = {
        "phi-1": "microsoft/phi-1",
        "phi-1_5": "microsoft/phi-1_5",
        "phi-2": "microsoft/phi-2",
        "phi-3-mini": "microsoft/phi-3-mini-4k-instruct",
        "phi-3-small": "microsoft/phi-3-small-8k-instruct",
        "phi-3-medium": "microsoft/phi-3-medium-4k-instruct",
    }

    # Other Transformers Models — HuggingFace hub ids
    TRANSFORMERS_MODELS = {
        "dialo-gpt-medium": "microsoft/DialoGPT-medium",
        "gpt2": "gpt2",
        "gpt2-medium": "gpt2-medium",
    }

    @classmethod
    def get_model_path(cls, model_name: str) -> Optional[str]:
        """Return the model path (LLaMA) or hub id (transformers) for
        ``model_name``, or ``None`` if the name is not registered.

        Lookup precedence mirrors the registry order: LLaMA first, then
        Phi, then other transformers models. All registered values are
        non-empty strings, so the ``or``-chain is safe.
        """
        return (
            cls.LLAMA_MODELS.get(model_name)
            or cls.PHI_MODELS.get(model_name)
            or cls.TRANSFORMERS_MODELS.get(model_name)
        )

    @classmethod
    def get_model_type(cls, model_name: str) -> str:
        """Return the backend for ``model_name``: ``"llama_cpp"``,
        ``"transformers"``, or ``"unknown"`` for unregistered names."""
        if model_name in cls.LLAMA_MODELS:
            return "llama_cpp"
        if model_name in cls.PHI_MODELS or model_name in cls.TRANSFORMERS_MODELS:
            return "transformers"
        return "unknown"

    @classmethod
    def list_models(cls) -> dict:
        """Return all registered model names, grouped by category."""
        return {
            "llama_models": list(cls.LLAMA_MODELS.keys()),
            "phi_models": list(cls.PHI_MODELS.keys()),
            "transformers_models": list(cls.TRANSFORMERS_MODELS.keys()),
        }


# Environment Configuration
class Config:
    """Main configuration class: all values come from environment
    variables, with hard-coded fallbacks for local development."""

    # Model settings
    DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "phi-1_5")
    MODEL_PATH = os.getenv("MODEL_PATH", "models/llama-2-7b-chat.Q4_K_M.gguf")
    TRANSFORMERS_MODEL = os.getenv("TRANSFORMERS_MODEL", "microsoft/phi-1_5")

    # API settings
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", "8000"))
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"

    # Model parameters
    DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "2048"))
    DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
    DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.9"))

    # Logging
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

    @classmethod
    def setup_model_environment(cls, model_name: str):
        """Point the relevant environment variable at ``model_name``.

        Resolves the model's path and backend via ModelConfig, then
        writes MODEL_PATH (llama_cpp) or TRANSFORMERS_MODEL
        (transformers). Returns True on success, False (with a printed
        notice) for unknown models.
        """
        resolved_path = ModelConfig.get_model_path(model_name)
        backend = ModelConfig.get_model_type(model_name)

        # Each loadable backend maps to the env var it reads and its log label.
        targets = {
            "llama_cpp": ("MODEL_PATH", "LLaMA"),
            "transformers": ("TRANSFORMERS_MODEL", "Transformers"),
        }
        if not (resolved_path and backend in targets):
            print(f"❌ Unknown model: {model_name}")
            return False

        env_var, label = targets[backend]
        os.environ[env_var] = resolved_path
        print(f"✅ Set up {label} model: {model_name} -> {resolved_path}")
        return True


# Convenience functions
def setup_phi_model(model_name: str = "phi-1_5"):
    """Convenience wrapper: configure the environment for a Phi model
    (defaults to phi-1_5) and return the success flag."""
    return Config.setup_model_environment(model_name)


def setup_llama_model(model_name: str = "llama-2-7b-chat"):
    """Convenience wrapper: configure the environment for a LLaMA model
    (defaults to llama-2-7b-chat) and return the success flag."""
    return Config.setup_model_environment(model_name)


def list_available_models():
    """Return the full model catalogue, grouped by category
    (delegates to ModelConfig.list_models)."""
    return ModelConfig.list_models()


if __name__ == "__main__":
    # Demo: print the model catalogue per category, then the active defaults.
    print("Available Models:")
    for category, model_list in list_available_models().items():
        heading = category.replace('_', ' ').title()
        print(f"\n{heading}:")
        for model in model_list:
            print(f"  - {model} ({ModelConfig.get_model_type(model)})")

    print(f"\nDefault model: {Config.DEFAULT_MODEL}")
    print(f"Model path: {Config.MODEL_PATH}")
    print(f"Transformers model: {Config.TRANSFORMERS_MODEL}")