Spaces:
Runtime error
Runtime error
| import time | |
| import numpy as np | |
| import numpy.typing as npt | |
| from flask import Flask, jsonify, request | |
| from huggingface_hub import hf_hub_download | |
| from llama_cpp import Llama | |
| from llama_cpp.llama_speculative import LlamaDraftModel | |
# Flask application object; HTTP handlers in this file attach to it.
app = Flask(__name__)

# TCP port handed to ``app.run`` when the file is executed as a script.
PORT = 7860

# Hugging Face Hub coordinates (repository id, GGUF filename) of the small
# draft model and of the larger main model it accelerates.
DRAFT_REPO, DRAFT_FILE = (
    "mradermacher/SmolLM2-135M-Instruct-GGUF",
    "SmolLM2-135M-Instruct.Q4_K_M.gguf",
)
MAIN_REPO, MAIN_FILE = (
    "mradermacher/SmolLM2-360M-Instruct-GGUF",
    "SmolLM2-360M-Instruct.Q4_K_M.gguf",
)
class LlamaModelDraft(LlamaDraftModel):
    """Wrap a smaller Llama GGUF model for speculative decoding.

    Instances are callable: given the token ids processed so far, they
    greedily sample up to ``num_pred_tokens`` continuation tokens from the
    wrapped draft model and return them as candidate tokens.
    """

    def __init__(self, draft: Llama, num_pred_tokens: int = 16):
        """Store the draft model and the per-call prediction budget.

        Args:
            draft: Loaded ``Llama`` instance used to propose tokens.
            num_pred_tokens: Maximum number of tokens proposed per call.
        """
        self.draft = draft
        self.num_pred_tokens = num_pred_tokens

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs
    ) -> npt.NDArray[np.intc]:
        """Propose continuation tokens for ``input_ids``.

        Args:
            input_ids: Token ids of the sequence so far.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Array of drafted token ids. It is empty when ``input_ids`` is
            empty or when the draft context has no room left for drafts.
        """
        tokens = input_ids.tolist()
        if not tokens:
            return np.array([], dtype=np.intc)

        # Keep prompt + drafted tokens inside the draft model's context
        # window; evaluating past it would raise and abort the caller's
        # generation instead of merely skipping the speed-up.
        budget = min(self.num_pred_tokens, self.draft.n_ctx() - len(tokens))
        if budget <= 0:
            return np.array([], dtype=np.intc)

        # The draft model is re-primed from scratch on every call.
        self.draft.reset()
        self.draft.eval(tokens)

        eos_token = self.draft.token_eos()
        predicted: list[int] = []
        for step in range(budget):
            token = self.draft.sample(temp=0.0)
            predicted.append(token)
            # Nothing useful follows end-of-sequence, and the forward pass
            # after the final sampled token would never be read.
            if token == eos_token or step == budget - 1:
                break
            self.draft.eval([token])
        return np.array(predicted, dtype=np.intc)
# Resolve local paths for both GGUF files via the Hugging Face Hub helper.
print("Downloading draft model...")
draft_path = hf_hub_download(filename=DRAFT_FILE, repo_id=DRAFT_REPO)
print("Downloading main model...")
main_path = hf_hub_download(filename=MAIN_FILE, repo_id=MAIN_REPO)

print("Loading models into memory...")

# Settings common to both models, kept in one place so they cannot drift apart.
_LLAMA_OPTIONS = {
    "n_ctx": 2048,
    "n_batch": 512,
    "n_threads": 2,
    "verbose": False,
}

# The draft model is built first because the main model wraps it.
draft_model = Llama(model_path=draft_path, **_LLAMA_OPTIONS)
main_model = Llama(
    model_path=main_path,
    draft_model=LlamaModelDraft(draft_model, num_pred_tokens=16),
    **_LLAMA_OPTIONS,
)

print("Models successfully loaded!")
# NOTE(review): no route registration for this handler is visible in the file,
# which would leave it unreachable. The path below is assumed from the
# OpenAI-style payload the handler accepts -- confirm it matches the clients.
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Generate a chat completion with the main model and report throughput.

    Expects a JSON object body with:
        messages: Required list of chat messages, passed to the model as-is.
        max_tokens: Optional integer (or null) generation limit; default 512.
        temperature: Optional number; defaults to 0.7 when absent or null.

    Returns:
        The model's completion response as JSON, extended with a
        ``system_performance`` object (tokens per second, generation time,
        acceleration technique). Malformed requests get a JSON ``error``
        message with HTTP status 400.
    """
    # silent=True turns a missing or unparseable JSON body into None so the
    # client receives the JSON error below rather than a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "messages" not in data:
        return jsonify({"error": "Missing messages array"}), 400

    messages = data["messages"]
    if not isinstance(messages, list):
        return jsonify({"error": "messages must be an array"}), 400

    max_tokens = data.get("max_tokens", 512)
    # bool is a subclass of int, so it has to be excluded explicitly.
    if max_tokens is not None and (
        isinstance(max_tokens, bool) or not isinstance(max_tokens, int)
    ):
        return jsonify({"error": "max_tokens must be an integer"}), 400

    temperature = data.get("temperature")
    if temperature is None:
        temperature = 0.7
    elif isinstance(temperature, bool) or not isinstance(
        temperature, (int, float)
    ):
        return jsonify({"error": "temperature must be a number"}), 400

    # perf_counter is monotonic, so clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    try:
        response = main_model.create_chat_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
        )
    except ValueError as exc:
        # Presumably raised for requests the model cannot serve (e.g. a prompt
        # larger than the context window); report it as a client error.
        return jsonify({"error": str(exc)}), 400
    generation_time = time.perf_counter() - start_time

    completion_tokens = response["usage"]["completion_tokens"]
    tps = completion_tokens / generation_time if generation_time > 0 else 0
    response["system_performance"] = {
        "tokens_per_second": round(tps, 2),
        "generation_time_sec": round(generation_time, 2),
        "acceleration_technique": "Speculative Decoding (135M draft + 360M main)",
    }
    return jsonify(response)
if __name__ == "__main__":
    # Serve on every interface. The development server handles requests in
    # separate threads by default, but all requests share the same two Llama
    # objects, so requests are processed one at a time instead.
    app.run(host="0.0.0.0", port=PORT, threaded=False)