File size: 2,786 Bytes
8710de9
9fb99fe
 
 
 
8d824b6
8710de9
9fb99fe
8710de9
 
9fb99fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8710de9
bda2055
9fb99fe
8d824b6
bda2055
9fb99fe
8d824b6
 
 
 
7d948f6
8d824b6
0a6e604
9fb99fe
0a6e604
 
8d824b6
 
0a6e604
8d824b6
0a6e604
9fb99fe
 
8710de9
8d824b6
8710de9
9fb99fe
 
8710de9
7d948f6
9fb99fe
7d948f6
9fb99fe
8710de9
9fb99fe
0a6e604
9fb99fe
0a6e604
9fb99fe
e102131
7d948f6
9fb99fe
e102131
9fb99fe
 
 
 
8710de9
0a6e604
9fb99fe
8710de9
9fb99fe
8710de9
 
9fb99fe
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import threading
import time

import numpy as np
import numpy.typing as npt
from flask import Flask, jsonify, request
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaDraftModel

app = Flask(__name__)
# Port passed to app.run() in the __main__ guard at the bottom of the file.
PORT = 7860

# GGUF weights downloaded from the Hugging Face Hub when this module is
# imported. The 135M model is wrapped as the speculative draft and the 360M
# model is the main model that serves requests.
# NOTE(review): speculative decoding assumes both models share one
# tokenizer/vocabulary — re-check that if either repo/file is swapped.
DRAFT_REPO = "mradermacher/SmolLM2-135M-Instruct-GGUF"
DRAFT_FILE = "SmolLM2-135M-Instruct.Q4_K_M.gguf"
MAIN_REPO = "mradermacher/SmolLM2-360M-Instruct-GGUF"
MAIN_FILE = "SmolLM2-360M-Instruct.Q4_K_M.gguf"


class LlamaModelDraft(LlamaDraftModel):
    """Propose speculative tokens with a smaller Llama GGUF model.

    The main ``Llama`` calls this object with the token ids it has so far and
    receives a 1-D ``np.intc`` array of guessed continuation tokens. Guesses
    are decoded greedily (``temp=0.0``), so they are deterministic.

    Args:
        draft: The small ``Llama`` instance that produces the guesses.
            NOTE(review): it must share the main model's vocabulary for the
            guesses to be meaningful — confirm when changing models.
        num_pred_tokens: Maximum number of tokens proposed per call.
    """

    def __init__(self, draft: Llama, num_pred_tokens: int = 16):
        super().__init__()
        self.draft = draft
        self.num_pred_tokens = num_pred_tokens

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs
    ) -> npt.NDArray[np.intc]:
        """Return up to ``num_pred_tokens`` greedy guesses following ``input_ids``.

        Returns an empty array when ``input_ids`` is empty or when the draft's
        own context window has no room for another token. Drafting stops
        right after the draft's end-of-sequence token, since any guess past it
        is wasted work.
        """
        tokens = input_ids.tolist()
        empty = np.array([], dtype=np.intc)
        if not tokens:
            return empty

        # At most len(tokens) + budget - 1 tokens are ever evaluated (the last
        # guess is never fed back), so this budget keeps the draft inside its
        # own n_ctx; evaluating past the context window fails.
        budget = min(self.num_pred_tokens, self.draft.n_ctx() - len(tokens))
        if budget <= 0:
            return empty

        eos = self.draft.token_eos()
        predicted: list[int] = []
        # Llama.generate (reset=True by default) reuses the longest cached
        # prefix shared with the previous call instead of re-evaluating the
        # whole prompt, which matters because consecutive calls share nearly
        # all of their input. Breaking out of the loop stops generation
        # before the final guess is evaluated.
        for token in self.draft.generate(tokens, temp=0.0):
            predicted.append(token)
            if token == eos or len(predicted) >= budget:
                break

        return np.array(predicted, dtype=np.intc)


# Import-time setup: the statements below run when the module is loaded,
# before the Flask server starts. hf_hub_download returns a local filesystem
# path to the requested GGUF file.
print("Downloading draft model...")
draft_path = hf_hub_download(repo_id=DRAFT_REPO, filename=DRAFT_FILE)

print("Downloading main model...")
main_path = hf_hub_download(repo_id=MAIN_REPO, filename=MAIN_FILE)

print("Loading models into memory...")
# Small model used only to propose candidate tokens; it is never called
# directly by the HTTP handler.
draft_model = Llama(
    model_path=draft_path,
    n_ctx=2048,
    n_batch=512,
    n_threads=2,
    verbose=False,
)

# Model that serves /v1/chat/completions. Same context/batch/thread settings
# as the draft; draft_model= makes llama-cpp-python ask the wrapped draft for
# up to 16 candidate tokens per call.
main_model = Llama(
    model_path=main_path,
    n_ctx=2048,
    n_batch=512,
    n_threads=2,
    draft_model=LlamaModelDraft(draft_model, num_pred_tokens=16),
    verbose=False,
)
print("Models successfully loaded!")


# llama_cpp.Llama objects carry mutable decoding state (KV cache, token
# buffers) and must not be driven by several threads at once, while Flask's
# built-in server handles requests on separate threads by default. This lock
# serializes generation so concurrent requests cannot corrupt that state.
_MODEL_LOCK = threading.Lock()


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Serve an OpenAI-style chat completion from the speculative-decoding model.

    Request JSON object:
        messages: Non-empty list of chat messages, forwarded unchanged to
            ``main_model.create_chat_completion``.
        max_tokens: Optional integer cap on generated tokens (default 512;
            an explicit ``null`` is forwarded as ``None``).
        temperature: Optional number (default 0.7 when absent or ``null``).

    Returns:
        The chat completion dict plus a ``system_performance`` object
        (tokens/second and wall-clock seconds for the whole generation call,
        prompt evaluation included), or ``{"error": ...}`` with HTTP 400 for a
        malformed request.
    """
    # silent=True: a missing/unparseable body or a non-JSON Content-Type
    # yields None (handled below) instead of an HTML error from Flask.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "messages" not in data:
        return jsonify({"error": "Missing messages array"}), 400

    messages = data["messages"]
    if not isinstance(messages, list) or not messages:
        return jsonify({"error": "messages must be a non-empty array"}), 400

    max_tokens = data.get("max_tokens", 512)
    # bool is a subclass of int, so reject it explicitly.
    if max_tokens is not None and (
        isinstance(max_tokens, bool) or not isinstance(max_tokens, int)
    ):
        return jsonify({"error": "max_tokens must be an integer"}), 400

    temperature = data.get("temperature")
    if temperature is None:
        temperature = 0.7
    if isinstance(temperature, bool) or not isinstance(temperature, (int, float)):
        return jsonify({"error": "temperature must be a number"}), 400

    with _MODEL_LOCK:
        # Start timing only once the model is ours, so time spent queued
        # behind other requests is not reported as generation time.
        # perf_counter is monotonic, unlike time.time().
        start_time = time.perf_counter()
        try:
            response = main_model.create_chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=False,
            )
        except ValueError as err:
            # llama-cpp-python raises ValueError for requests it cannot
            # serve, e.g. a prompt that exceeds the context window.
            return jsonify({"error": str(err)}), 400
        generation_time = time.perf_counter() - start_time

    completion_tokens = response["usage"]["completion_tokens"]
    tps = completion_tokens / generation_time if generation_time > 0 else 0

    response["system_performance"] = {
        "tokens_per_second": round(tps, 2),
        "generation_time_sec": round(generation_time, 2),
        "acceleration_technique": "Speculative Decoding (135M draft + 360M main)",
    }

    return jsonify(response)


if __name__ == "__main__":
    # Bind every network interface (not just localhost) on the configured PORT.
    app.run(port=PORT, host="0.0.0.0")