# gradio / app.py
# Last change by GamerC0der: "Update app.py" (commit 9fb99fe, verified)
import threading
import time

import numpy as np
import numpy.typing as npt
from flask import Flask, jsonify, request
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaDraftModel
# GGUF checkpoints on the Hugging Face Hub: the small 135M model drafts tokens
# that the larger 360M model verifies (speculative decoding).
DRAFT_REPO = "mradermacher/SmolLM2-135M-Instruct-GGUF"
DRAFT_FILE = "SmolLM2-135M-Instruct.Q4_K_M.gguf"
MAIN_REPO = "mradermacher/SmolLM2-360M-Instruct-GGUF"
MAIN_FILE = "SmolLM2-360M-Instruct.Q4_K_M.gguf"

# TCP port the HTTP server listens on.
PORT = 7860

app = Flask(__name__)
class LlamaModelDraft(LlamaDraftModel):
    """Wrap a smaller Llama GGUF model for speculative decoding.

    Each call greedily predicts up to ``num_pred_tokens`` tokens that follow
    ``input_ids``, using the wrapped ``draft`` model. The wrapper is passed to
    the main ``Llama`` as ``draft_model``, which checks the proposals.

    Tokens already evaluated by the draft model are kept between calls. Only
    the part of ``input_ids`` that differs from the previous call is
    re-evaluated, instead of the whole sequence every step.
    """

    def __init__(self, draft: Llama, num_pred_tokens: int = 16):
        # draft: the small model that proposes tokens; its state is mutated
        # on every call.
        # num_pred_tokens: upper bound on the tokens proposed per call.
        self.draft = draft
        self.num_pred_tokens = num_pred_tokens

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs
    ) -> npt.NDArray[np.intc]:
        """Return up to ``num_pred_tokens`` greedy draft tokens after ``input_ids``.

        Returns an empty array when ``input_ids`` is empty or the draft
        context has no room for another token. Prediction stops right after
        an EOS token. ``kwargs`` is accepted for interface compatibility and
        is ignored.
        """
        tokens = input_ids.tolist()
        # Stay inside the draft's context window: the main model may pass a
        # sequence close to n_ctx, and evaluating past the window fails.
        budget = min(self.num_pred_tokens, self.draft.n_ctx() - len(tokens))
        if not tokens or budget <= 0:
            return np.array([], dtype=np.intc)

        self._sync_prefix(tokens)
        eos = self.draft.token_eos()
        predicted: list[int] = []
        for step in range(budget):
            token = self.draft.sample(temp=0.0)
            predicted.append(token)
            # Nothing useful follows EOS. The last token's logits are never
            # read, so skip evaluating it.
            if token == eos or step == budget - 1:
                break
            self.draft.eval([token])
        return np.array(predicted, dtype=np.intc)

    def _sync_prefix(self, tokens: list[int]) -> None:
        """Evaluate ``tokens`` on the draft model, reusing its cached prefix."""
        cached = self.draft.input_ids[: self.draft.n_tokens].tolist()
        prefix = 0
        # Match against tokens[:-1] so at least the last token is evaluated
        # and sample() reads logits produced for it.
        for cached_token, token in zip(cached, tokens[:-1]):
            if cached_token != token:
                break
            prefix += 1
        # NOTE(review): relies on Llama.eval() discarding KV-cache entries at
        # positions >= n_tokens (the mechanism Llama.generate uses for its
        # own prefix reuse). Confirm against the pinned llama-cpp-python version.
        self.draft.n_tokens = prefix
        self.draft.eval(tokens[prefix:])
# llama.cpp runtime settings, identical for the draft and main models.
_LLAMA_RUNTIME = {
    "n_ctx": 2048,
    "n_batch": 512,
    "n_threads": 2,
    "verbose": False,
}

# Fetch both GGUF checkpoints; hf_hub_download yields the local file path.
print("Downloading draft model...")
draft_path = hf_hub_download(repo_id=DRAFT_REPO, filename=DRAFT_FILE)
print("Downloading main model...")
main_path = hf_hub_download(repo_id=MAIN_REPO, filename=MAIN_FILE)

print("Loading models into memory...")
draft_model = Llama(model_path=draft_path, **_LLAMA_RUNTIME)
# The main model gets token proposals from the draft wrapper, up to 16 per
# speculative step.
main_model = Llama(
    model_path=main_path,
    draft_model=LlamaModelDraft(draft_model, num_pred_tokens=16),
    **_LLAMA_RUNTIME,
)
print("Models successfully loaded!")
# Serializes access to the models. Both Llama instances hold mutable decoding
# state (the draft wrapper rewinds and re-evaluates draft_model on every call),
# and Flask's development server handles requests on multiple threads.
_GENERATION_LOCK = threading.Lock()

_DEFAULT_MAX_TOKENS = 512
_DEFAULT_TEMPERATURE = 0.7


def _is_number(value):
    """Return True for int/float values, excluding bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _parse_chat_request(data):
    """Validate a chat-completion request body.

    Returns ``(params, None)`` on success, where ``params`` holds the keyword
    arguments for ``create_chat_completion``. Returns ``(None, message)``
    when the body is invalid.
    """
    if not isinstance(data, dict) or "messages" not in data:
        return None, "Missing messages array"
    messages = data["messages"]
    if not isinstance(messages, list) or not messages:
        return None, "messages must be a non-empty array"
    if not all(isinstance(message, dict) for message in messages):
        return None, "each message must be an object"

    # An explicit null is treated as "use the default" rather than unbounded.
    max_tokens = data.get("max_tokens")
    if max_tokens is None:
        max_tokens = _DEFAULT_MAX_TOKENS
    elif isinstance(max_tokens, bool) or not isinstance(max_tokens, int) or max_tokens <= 0:
        return None, "max_tokens must be a positive integer"

    temperature = data.get("temperature")
    if temperature is None:
        temperature = _DEFAULT_TEMPERATURE
    elif not _is_number(temperature) or temperature < 0:
        return None, "temperature must be a non-negative number"

    return {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }, None


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Run a non-streaming chat completion and attach timing statistics.

    Accepts a JSON body with ``messages`` (required), plus optional
    ``max_tokens`` (default 512) and ``temperature`` (default 0.7).
    Returns the completion JSON with an extra ``system_performance`` object,
    or ``{"error": ...}`` with HTTP 400 on invalid input.
    """
    # silent=True: a missing or non-JSON body becomes None (-> 400 below)
    # instead of Flask's HTML 415/400 error page.
    params, error = _parse_chat_request(request.get_json(silent=True))
    if error is not None:
        return jsonify({"error": error}), 400

    with _GENERATION_LOCK:
        # Start timing after acquiring the lock so queueing time is not
        # counted as generation time.
        start_time = time.perf_counter()
        try:
            response = main_model.create_chat_completion(stream=False, **params)
        except ValueError as exc:
            # e.g. the prompt plus max_tokens exceeds the context window.
            return jsonify({"error": str(exc)}), 400
        generation_time = time.perf_counter() - start_time

    completion_tokens = response["usage"]["completion_tokens"]
    tps = completion_tokens / generation_time if generation_time > 0 else 0
    response["system_performance"] = {
        "tokens_per_second": round(tps, 2),
        "generation_time_sec": round(generation_time, 2),
        "acceleration_technique": "Speculative Decoding (135M draft + 360M main)",
    }
    return jsonify(response)
if __name__ == "__main__":
    # Listen on all network interfaces (not only localhost) at PORT.
    server_options = {"host": "0.0.0.0", "port": PORT}
    app.run(**server_options)