Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,25 +3,24 @@ from flask import Flask, request, jsonify
|
|
| 3 |
from llama_cpp import Llama
|
| 4 |
|
| 5 |
app = Flask(__name__)
|
| 6 |
-
|
| 7 |
GRADIO_PORT = 7860
|
| 8 |
|
| 9 |
draft_model = Llama.from_pretrained(
|
| 10 |
repo_id="QuantFactory/SmolLM2-135M-Instruct-GGUF",
|
| 11 |
-
filename="SmolLM2-135M-Instruct.Q4_0.gguf",
|
| 12 |
n_ctx=2048,
|
|
|
|
| 13 |
n_threads=2,
|
| 14 |
-
flash_attn=True,
|
| 15 |
verbose=False
|
| 16 |
)
|
| 17 |
|
| 18 |
main_model = Llama.from_pretrained(
|
| 19 |
repo_id="QuantFactory/SmolLM2-360M-Instruct-GGUF",
|
| 20 |
-
filename="SmolLM2-360M-Instruct.Q4_0.gguf",
|
| 21 |
n_ctx=2048,
|
|
|
|
| 22 |
n_threads=2,
|
| 23 |
-
|
| 24 |
-
draft_model=draft_model,
|
| 25 |
verbose=False
|
| 26 |
)
|
| 27 |
|
|
@@ -30,22 +29,26 @@ def chat_completions():
|
|
| 30 |
data = request.json or {}
|
| 31 |
if 'messages' not in data:
|
| 32 |
return jsonify({"error": "Missing messages array"}), 400
|
| 33 |
-
|
| 34 |
start_time = time.time()
|
|
|
|
| 35 |
response = main_model.create_chat_completion(
|
| 36 |
messages=data.get('messages', []),
|
| 37 |
temperature=0.7,
|
| 38 |
max_tokens=data.get('max_tokens', 512),
|
| 39 |
-
stream=False
|
|
|
|
| 40 |
)
|
| 41 |
-
generation_time = time.time() - start_time
|
| 42 |
|
|
|
|
| 43 |
tps = response['usage']['completion_tokens'] / generation_time if generation_time > 0 else 0
|
|
|
|
| 44 |
response['system_performance'] = {
|
| 45 |
"tokens_per_second": round(tps, 2),
|
| 46 |
"generation_time_sec": round(generation_time, 2),
|
| 47 |
-
"acceleration_technique": "Lossless Speculative Decoding"
|
| 48 |
}
|
|
|
|
| 49 |
return jsonify(response)
|
| 50 |
|
| 51 |
if __name__ == '__main__':
|
|
|
|
| 3 |
from llama_cpp import Llama
|
| 4 |
|
| 5 |
app = Flask(__name__)
|
|
|
|
| 6 |
GRADIO_PORT = 7860
|
| 7 |
|
| 8 |
draft_model = Llama.from_pretrained(
|
| 9 |
repo_id="QuantFactory/SmolLM2-135M-Instruct-GGUF",
|
| 10 |
+
filename="*SmolLM2-135M-Instruct.Q4_0.gguf",
|
| 11 |
n_ctx=2048,
|
| 12 |
+
n_batch=512,
|
| 13 |
n_threads=2,
|
|
|
|
| 14 |
verbose=False
|
| 15 |
)
|
| 16 |
|
| 17 |
main_model = Llama.from_pretrained(
|
| 18 |
repo_id="QuantFactory/SmolLM2-360M-Instruct-GGUF",
|
| 19 |
+
filename="*SmolLM2-360M-Instruct.Q4_0.gguf",
|
| 20 |
n_ctx=2048,
|
| 21 |
+
n_batch=512,
|
| 22 |
n_threads=2,
|
| 23 |
+
draft_model=draft_model,
|
|
|
|
| 24 |
verbose=False
|
| 25 |
)
|
| 26 |
|
|
|
|
| 29 |
data = request.json or {}
|
| 30 |
if 'messages' not in data:
|
| 31 |
return jsonify({"error": "Missing messages array"}), 400
|
| 32 |
+
|
| 33 |
start_time = time.time()
|
| 34 |
+
|
| 35 |
response = main_model.create_chat_completion(
|
| 36 |
messages=data.get('messages', []),
|
| 37 |
temperature=0.7,
|
| 38 |
max_tokens=data.get('max_tokens', 512),
|
| 39 |
+
stream=False,
|
| 40 |
+
cache_prompt=True
|
| 41 |
)
|
|
|
|
| 42 |
|
| 43 |
+
generation_time = time.time() - start_time
|
| 44 |
tps = response['usage']['completion_tokens'] / generation_time if generation_time > 0 else 0
|
| 45 |
+
|
| 46 |
response['system_performance'] = {
|
| 47 |
"tokens_per_second": round(tps, 2),
|
| 48 |
"generation_time_sec": round(generation_time, 2),
|
| 49 |
+
"acceleration_technique": "Lossless Speculative Decoding + Prompt Caching"
|
| 50 |
}
|
| 51 |
+
|
| 52 |
return jsonify(response)
|
| 53 |
|
| 54 |
if __name__ == '__main__':
|