| import json |
|
|
| from huggingface_hub import hf_hub_download |
| from llama_cpp import Llama |
|
|
| from config import CONTEXT_SIZE, FILENAME, N_GPU_LAYERS, N_THREADS, REPO_ID |
|
|
|
|
class ModelEngine:
    """Own a single llama.cpp model instance and expose chat completion.

    The model file is resolved with ``hf_hub_download`` using ``REPO_ID`` and
    ``FILENAME`` from ``config`` and then handed to ``Llama``. Loading is
    best-effort: a failure is reported on stdout and leaves ``self.llm`` as
    ``None`` rather than raising, so constructing the engine never fails.
    The failure only surfaces later, when ``generate`` is called.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # The loaded ``Llama`` object, or None when loading has not
        # succeeded.
        self.llm = None
        # The exception that made loading fail (None if it did not fail);
        # kept so ``generate`` can report *why* the model is unavailable.
        self._load_error = None
        self._load_model()

    def _load_model(self):
        """Fetch the model file and build the ``Llama`` instance.

        Never raises for load failures: any ``Exception`` raised by the
        download or by ``Llama`` construction is printed, stored in
        ``self._load_error``, and ``self.llm`` is reset to ``None``.
        """
        print(f"Loading model {REPO_ID}...")
        try:
            # Only the two calls that can actually fail live in the try body.
            model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
            self.llm = Llama(
                model_path=model_path,
                n_ctx=CONTEXT_SIZE,
                n_threads=N_THREADS,
                n_gpu_layers=N_GPU_LAYERS,
                n_batch=512,
                verbose=True,
            )
        except Exception as e:
            # Deliberate best-effort boundary: report and keep the process
            # alive; callers find out through ``generate``.
            print(f"CRITICAL ERROR: Failed to load model. {e}")
            self.llm = None
            self._load_error = e
        else:
            self._load_error = None
            print("Model loaded successfully.")

    def generate(self, messages, max_tokens, temperature, stream=True):
        """Run a chat completion on the loaded model.

        Args:
            messages: Chat messages, passed through unchanged to
                ``create_chat_completion``.
            max_tokens: Token limit; coerced with ``int()``.
            temperature: Sampling temperature; coerced with ``float()``.
            stream: Forwarded as the ``stream`` argument (default True).

        Returns:
            Whatever ``self.llm.create_chat_completion`` returns. Per the
            llama-cpp-python docs this is presumably an iterator of chunks
            when ``stream`` is true and a response dict otherwise; confirm
            against the installed version.

        Raises:
            RuntimeError: If no model is loaded. When the load failure was
                recorded, it is attached as the exception's ``__cause__``.
            ValueError / TypeError: If ``max_tokens`` or ``temperature``
                cannot be converted by ``int()`` / ``float()``.
        """
        # Compare against the None sentinel explicitly instead of relying on
        # the truthiness of the model object.
        if self.llm is None:
            # getattr keeps this safe for instances that bypassed __init__.
            cause = getattr(self, "_load_error", None)
            raise RuntimeError("Model is not loaded.") from cause

        return self.llm.create_chat_completion(
            messages=messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            stream=stream,
        )
|
|
|
|
| |
# Module-level shared engine: constructing it here triggers the model load
# once, at import time. A failed load is printed rather than raised, so the
# import itself still succeeds and `engine.llm` is left as None.
engine = ModelEngine()
|
|