File size: 5,160 Bytes
2a72045
3274ec4
2a72045
 
 
 
 
 
3274ec4
9d2777a
 
 
 
 
 
3274ec4
 
 
 
 
 
 
 
 
 
 
 
2a72045
 
 
264847d
 
2a72045
264847d
 
3274ec4
264847d
2a72045
3274ec4
9d2777a
2a72045
3274ec4
2a72045
3274ec4
 
 
 
2a72045
3274ec4
 
 
 
9d2777a
3274ec4
 
 
 
 
 
 
9d2777a
3274ec4
 
 
9d2777a
3274ec4
2a72045
 
 
 
 
3274ec4
 
 
 
 
 
 
 
2a72045
 
3274ec4
2a72045
3274ec4
 
2a72045
3274ec4
 
 
 
 
 
 
 
 
 
 
 
 
 
2a72045
 
3274ec4
2a72045
3274ec4
9d2777a
 
3274ec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a72045
64f495c
 
3274ec4
 
 
64f495c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
from llama_cpp import Llama
import requests
from typing import Generator

class ModelManager:
    """Download, cache, and run small GGUF chat models via llama-cpp.

    Model files are cached under ``./models``. The models listed in
    ``critical_models`` are fetched eagerly at construction time so the
    first chat request does not stall on a large download; everything
    else is downloaded lazily by ``load_model``.
    """

    def __init__(self):
        # model_id -> loaded Llama instance (populated lazily by load_model).
        self.models = {}
        # Static registry of supported models: where to fetch each GGUF file
        # and which prompt template ("format") the model expects.
        self.model_configs = {
            "fast-chat": {
                "repo": "Qwen/Qwen2.5-0.5B-Instruct-GGUF",
                "file": "qwen2.5-0.5b-instruct-q4_k_m.gguf",
                "url": "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf",
                "format": "chatml"
            },
            "tinyllama": {
                "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                "file": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
                "url": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
                "format": "tinyllama"
            },
            "coder": {
                "repo": "Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF",
                "file": "qwen2.5-coder-1.5b-instruct-q4_k_m.gguf",
                "url": "https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf",
                "format": "chatml"
            }
        }
        self.models_dir = os.path.join(os.getcwd(), "models")
        os.makedirs(self.models_dir, exist_ok=True)
        # Only the smallest model is considered critical; the rest are
        # downloaded on first use.
        self.critical_models = ["fast-chat"]
        self.auto_download_critical()

    def auto_download_critical(self):
        """Download only critical lightweight models at startup.

        Best-effort: a failure (e.g. no network) is reported but does not
        prevent construction — the model can still be fetched lazily later.
        """
        print("Checking for pre-downloaded models...")
        for model_id in self.critical_models:
            try:
                self.download_model(model_id)
                print(f"✓ {model_id} ready")
            except Exception as e:
                print(f"✗ Failed to ensure {model_id}: {e}")

    def download_model(self, model_id: str):
        """Ensure the GGUF file for *model_id* is on disk; return its path.

        Raises ``ValueError`` for an unconfigured id and re-raises any
        network/filesystem error after removing the partial download.
        """
        config = self.model_configs.get(model_id)
        if not config:
            raise ValueError(f"Model {model_id} not configured")

        target_path = os.path.join(self.models_dir, config["file"])
        # Heuristic: every real model file here is far larger than 50 MB,
        # so anything smaller is treated as truncated and re-fetched.
        if os.path.exists(target_path) and os.path.getsize(target_path) > 50000000:
            return target_path

        # Stream into a sibling temp file and atomically rename at the end,
        # so a crash mid-download never leaves a half-written file at the
        # final path where the size heuristic above could be fooled.
        temp_path = target_path + ".part"
        print(f"Downloading {model_id}...")
        try:
            response = requests.get(config["url"], stream=True, timeout=60)
            response.raise_for_status()
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
            os.replace(temp_path, target_path)  # atomic on POSIX and Windows
            print(f"✓ {model_id} downloaded")
            return target_path
        except Exception:
            # Clean up the partial file; bare `raise` preserves the
            # original traceback for the caller.
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise

    def load_model(self, model_id: str):
        """Return a cached Llama instance for *model_id*, loading on first use.

        Downloads the model file if necessary, then constructs the Llama
        runtime with a small context window and thread count suited to
        constrained hosts.
        """
        if model_id in self.models:
            return self.models[model_id]

        path = self.download_model(model_id)
        self.models[model_id] = Llama(
            model_path=path,
            n_ctx=1024,
            n_threads=2,
            verbose=False
        )
        return self.models[model_id]

    def format_prompt(self, model_id: str, system: str, history: list, prompt: str):
        """Render a (prompt, stop_tokens) pair in the model's chat template.

        *history* is a list of ``{"role": ..., "content": ...}`` dicts; any
        role other than "user" is treated as the assistant. Unknown formats
        fall back to the raw prompt with a generic EOS stop token.
        """
        fmt = self.model_configs[model_id]["format"]

        if fmt == "chatml":
            full = f"<|im_start|>system\n{system}<|im_end|>\n"
            for msg in history:
                role = "user" if msg["role"] == "user" else "assistant"
                full += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
            full += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
            return full, ["<|im_end|>", "###", "<|im_start|>", "</s>"]

        elif fmt == "tinyllama":
            full = f"<|system|>\n{system}</s>\n"
            for msg in history:
                role = "user" if msg["role"] == "user" else "assistant"
                full += f"<|{role}|>\n{msg['content']}</s>\n"
            full += f"<|user|>\n{prompt}</s>\n<|assistant|>\n"
            return full, ["</s>", "<|user|>", "<|assistant|>"]

        return prompt, ["</s>"]

    def generate_stream(self, model_id: str, prompt: str, context: list = None, **kwargs) -> Generator[str, None, None]:
        """Yield completion text tokens for *prompt*, one chunk at a time.

        *context* is prior chat history in the same shape format_prompt
        expects. Sampling parameters (max_tokens, temperature, top_p) may
        be overridden via keyword arguments.
        """
        llm = self.load_model(model_id)

        system_text = (
            "You are a helpful AI assistant. "
            "For math, use LaTeX with $ $ for display and \\( \\) for inline."
        )

        full_prompt, stop_tokens = self.format_prompt(model_id, system_text, context or [], prompt)

        params = {
            "max_tokens": kwargs.get("max_tokens", 512),
            "stop": stop_tokens,
            "stream": True,
            "temperature": kwargs.get("temperature", 0.7),
            "top_p": kwargs.get("top_p", 0.95)
        }

        for output in llm(full_prompt, **params):
            token = output["choices"][0]["text"]
            yield token

    def cleanup(self):
        """Release every loaded model and clear the cache."""
        for model in self.models.values():
            if hasattr(model, 'close'):
                model.close()
        self.models.clear()