"""Flask chat-completion endpoint backed by llama.cpp speculative decoding.

A small SmolLM2-135M GGUF model drafts candidate tokens that the larger
SmolLM2-360M model consumes through llama-cpp-python's ``draft_model`` hook.
Both models are downloaded and loaded when this module is imported.
"""

import threading
import time

import numpy as np
import numpy.typing as npt
from flask import Flask, jsonify, request
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaDraftModel

app = Flask(__name__)

PORT = 7860

DRAFT_REPO = "mradermacher/SmolLM2-135M-Instruct-GGUF"
DRAFT_FILE = "SmolLM2-135M-Instruct.Q4_K_M.gguf"

MAIN_REPO = "mradermacher/SmolLM2-360M-Instruct-GGUF"
MAIN_FILE = "SmolLM2-360M-Instruct.Q4_K_M.gguf"


class LlamaModelDraft(LlamaDraftModel):
    """Wrap a smaller Llama GGUF model for speculative decoding."""

    def __init__(self, draft: Llama, num_pred_tokens: int = 16):
        """Store the draft model and the maximum tokens to draft per call.

        Args:
            draft: Loaded ``Llama`` instance used to propose tokens.
            num_pred_tokens: Upper bound on tokens drafted per invocation.
        """
        self.draft = draft
        self.num_pred_tokens = num_pred_tokens

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs
    ) -> npt.NDArray[np.intc]:
        """Greedily draft up to ``num_pred_tokens`` continuation tokens.

        The draft model is reset and re-evaluated on the full ``input_ids``
        sequence on every call, then sampled at temperature 0.

        Args:
            input_ids: Token ids of the sequence generated so far.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            An ``np.intc`` array of drafted token ids. It is empty when
            ``input_ids`` is empty or when the draft model's context window
            has no room left for any prediction.
        """
        tokens = input_ids.tolist()
        if not tokens:
            return np.array([], dtype=np.intc)

        # Never draft past the draft model's context window: evaluating the
        # prompt plus the drafted tokens must fit inside n_ctx.
        budget = min(self.num_pred_tokens, self.draft.n_ctx() - len(tokens))
        if budget <= 0:
            return np.array([], dtype=np.intc)

        self.draft.reset()
        self.draft.eval(tokens)

        predicted: list[int] = []
        for step in range(budget):
            token = self.draft.sample(temp=0.0)
            predicted.append(token)
            # The last drafted token does not need to be evaluated: nothing
            # is sampled after it and the next call resets the model anyway.
            if step + 1 < budget:
                self.draft.eval([token])

        return np.array(predicted, dtype=np.intc)


print("Downloading draft model...")
draft_path = hf_hub_download(repo_id=DRAFT_REPO, filename=DRAFT_FILE)

print("Downloading main model...")
main_path = hf_hub_download(repo_id=MAIN_REPO, filename=MAIN_FILE)

print("Loading models into memory...")
draft_model = Llama(
    model_path=draft_path,
    n_ctx=2048,
    n_batch=512,
    n_threads=2,
    verbose=False,
)
main_model = Llama(
    model_path=main_path,
    n_ctx=2048,
    n_batch=512,
    n_threads=2,
    draft_model=LlamaModelDraft(draft_model, num_pred_tokens=16),
    verbose=False,
)
print("Models successfully loaded!")

# Both Llama instances are shared module-level objects whose state is mutated
# during generation (reset/eval above), and the Flask server may handle
# requests on several threads, so inference is serialized behind one lock.
_inference_lock = threading.Lock()


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Handle a chat-completion request and attach timing statistics.

    Expects a JSON object with a ``messages`` array and an optional integer
    ``max_tokens`` (default 512). Sampling temperature is fixed at 0.7.

    Returns:
        The model's completion response as JSON, extended with a
        ``system_performance`` object, or a JSON ``{"error": ...}`` body
        with status 400 when the request is malformed or rejected by the
        model with a ``ValueError``.
    """
    data = request.get_json() or {}
    # A JSON string or array would pass the membership test below and then
    # crash on ``.get``; reject anything that is not an object up front.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    if "messages" not in data:
        return jsonify({"error": "Missing messages array"}), 400

    messages = data["messages"]
    if not isinstance(messages, list):
        return jsonify({"error": "messages must be an array"}), 400

    max_tokens = data.get("max_tokens", 512)
    # bool is a subclass of int, so exclude it explicitly.
    if max_tokens is not None and (
        isinstance(max_tokens, bool) or not isinstance(max_tokens, int)
    ):
        return jsonify({"error": "max_tokens must be an integer"}), 400

    with _inference_lock:
        # Start timing after the lock is held so time spent waiting on other
        # requests is not counted as generation time. perf_counter is
        # monotonic, unlike time.time().
        start_time = time.perf_counter()
        try:
            response = main_model.create_chat_completion(
                messages=messages,
                temperature=0.7,
                max_tokens=max_tokens,
                stream=False,
            )
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400
        generation_time = time.perf_counter() - start_time

    completion_tokens = response["usage"]["completion_tokens"]
    tps = completion_tokens / generation_time if generation_time > 0 else 0

    response["system_performance"] = {
        "tokens_per_second": round(tps, 2),
        "generation_time_sec": round(generation_time, 2),
        "acceleration_technique": "Speculative Decoding (135M draft + 360M main)",
    }
    return jsonify(response)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=PORT)