File size: 4,059 Bytes
24a7f55
4a0ad64
24a7f55
4a0ad64
24a7f55
 
 
4a0ad64
24a7f55
 
4a0ad64
 
 
24a7f55
 
 
 
 
4a0ad64
 
24a7f55
 
 
4a0ad64
 
 
24a7f55
4a0ad64
24a7f55
4a0ad64
 
 
 
 
 
 
 
 
 
 
 
24a7f55
 
 
 
4a0ad64
 
 
 
 
 
 
 
 
 
 
 
24a7f55
4a0ad64
 
24a7f55
 
 
4a0ad64
 
 
 
24a7f55
4a0ad64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24a7f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Configuration for model management.

This module provides configuration for loading and managing models
from Hugging Face's model hub.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Literal
import os

ModelType = Literal["text-generation", "text-embedding", "vision", "multimodal"]
DeviceType = Literal["auto", "cpu", "cuda"]

@dataclass
class ModelConfig:
    """Configuration for a single model.

    Bundles the Hugging Face hub identifier, the local cache directory,
    and the loading options used when the model is instantiated.
    """
    model_id: str                     # Hugging Face hub repo id, e.g. "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model_path: str                   # local directory where the snapshot is (or will be) stored
    model_type: ModelType             # one of the ModelType literals above
    device: DeviceType = "auto"       # "auto" lets the loader pick cpu/cuda
    quantize: bool = True             # request a quantized load where supported
    use_safetensors: bool = True      # prefer .safetensors weights over pickle-based formats
    # NOTE(review): trust_remote_code=True executes arbitrary code shipped in
    # the model repo on load — confirm this default is intentional for all models.
    trust_remote_code: bool = True
    description: str = ""             # human-readable summary, shown to users
    size_gb: float = 0.0  # Approximate size in GB
    recommended: bool = False         # surfaced as a suggested default in UIs

# Available models with their configurations.
# Keys are the short names accepted by get_model_config()/get_model_path();
# values describe the hub repo and local layout. Grouped by approximate
# resource footprint so callers can pick a tier at a glance.
DEFAULT_MODELS = {
    # Lightweight models (under 2GB)
    "tiny-llama": ModelConfig(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        model_path="./models/tiny-llama-1.1b-chat",
        model_type="text-generation",
        quantize=True,
        description="Very small and fast model, good for quick testing",
        size_gb=1.1,
        recommended=True
    ),
    
    # Medium models (2-10GB)
    "mistral-7b": ModelConfig(
        # GPTQ build: pre-quantized weights, hence the small size_gb for a 7B model.
        model_id="TheBloke/Mistral-7B-Instruct-v0.1-GPTQ",
        model_path="./models/mistral-7b-instruct-gptq",
        model_type="text-generation",
        quantize=True,
        description="Good balance of performance and resource usage",
        size_gb=4.0,
        recommended=True
    ),
    
    "llama2-7b": ModelConfig(
        # NOTE(review): gated repo — downloading requires accepting Meta's
        # license and an authenticated HF token; verify token handling upstream.
        model_id="meta-llama/Llama-2-7b-chat-hf",
        model_path="./models/llama2-7b-chat",
        model_type="text-generation",
        description="High quality 7B parameter model from Meta",
        size_gb=13.0
    ),
    
    # Embedding models
    "all-mpnet-base-v2": ModelConfig(
        model_id="sentence-transformers/all-mpnet-base-v2",
        model_path="./models/all-mpnet-base-v2",
        model_type="text-embedding",
        description="General purpose sentence transformer, good balance of speed and quality",
        size_gb=0.4,
        recommended=True
    ),
    
    "bge-small-en": ModelConfig(
        model_id="BAAI/bge-small-en-v1.5",
        model_path="./models/bge-small-en",
        model_type="text-embedding",
        description="Small but powerful embedding model",
        size_gb=0.13
    ),
    
    # Larger models (10GB+)
    "mixtral-8x7b": ModelConfig(
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_path="./models/mixtral-8x7b-instruct",
        model_type="text-generation",
        description="Very powerful 8x7B MoE model, requires significant resources",
        size_gb=85.0
    ),
    
    "llama2-13b": ModelConfig(
        # NOTE(review): gated repo — same license/token requirement as llama2-7b.
        model_id="meta-llama/Llama-2-13b-chat-hf",
        model_path="./models/llama2-13b-chat",
        model_type="text-generation",
        description="High quality 13B parameter model, better reasoning capabilities",
        size_gb=26.0
    )
}

def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up the configuration registered under *model_name*.

    Returns None when the name is not present in DEFAULT_MODELS.
    """
    try:
        return DEFAULT_MODELS[model_name]
    except KeyError:
        return None

def list_available_models() -> List[str]:
    """Return the short names of every model in DEFAULT_MODELS."""
    return [name for name in DEFAULT_MODELS]

def get_model_path(model_name: str) -> str:
    """Return the local directory for *model_name*, downloading it if needed.

    Ensures the target directory exists and, when the model snapshot has
    not been fetched yet, downloads the repository from the Hugging Face
    hub into ``config.model_path``.

    Args:
        model_name: Key into DEFAULT_MODELS.

    Returns:
        The local filesystem path containing the model files.

    Raises:
        ValueError: If *model_name* is not a known model.
    """
    config = get_model_config(model_name)
    if not config:
        raise ValueError(f"Unknown model: {model_name}")
    
    # Create model directory if it doesn't exist
    os.makedirs(config.model_path, exist_ok=True)
    
    # A missing config.json is used as the "not yet downloaded" marker.
    # NOTE(review): this does not detect partially completed downloads;
    # snapshot_download resumes/refreshes the snapshot in that case.
    if not os.path.exists(os.path.join(config.model_path, "config.json")):
        # Imported lazily so the module stays usable without
        # huggingface_hub installed when models are pre-downloaded.
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=config.model_id,
            local_dir=config.model_path,
            # local_dir_use_symlinks was deprecated (huggingface_hub 0.23+
            # ignores it with a warning); files are materialized in local_dir.
            ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
        )
    
    return config.model_path