| import run_stok
|
| import json
|
| from flask import Flask, jsonify, Response, stream_with_context, request
|
| from datetime import datetime
|
| import sys
|
| from run_stok import load_model, run_model
|
| import time
|
|
|
class custom_colors:
    """ANSI 24-bit (truecolor) escape-code helpers plus a few preset colors."""

    def __init__(self):
        # Preset foreground colors as raw truecolor escape sequences.
        self.cream = "\033[38;2;255;245;151m"
        self.beige = "\033[38;2;245;245;220m"
        self.green = "\033[38;2;52;199;89m"
        self.red = self.use_hex("#d53410")
        # Reset-all escape; append after colored text to restore defaults.
        self.esc = "\033[0m"

    def use_hex(self, hex_code: str) -> str:
        """Return the ANSI foreground escape for a ``#RRGGBB`` hex color.

        Args:
            hex_code: six hex digits, with or without a leading ``#``.

        Raises:
            ValueError: if the code is not exactly 6 hex digits
                (validation added for consistency with ``use_hex_bg``).
        """
        hex_code = hex_code.lstrip("#")
        if len(hex_code) != 6:
            raise ValueError("Hex code must be 6 characters long (e.g., 'RRGGBB').")
        r, g, b = (int(hex_code[i:i + 2], 16) for i in (0, 2, 4))
        return f"\033[38;2;{r};{g};{b}m"

    def use_hex_bg(self, hex_code: str) -> str:
        """Return the ANSI background escape for a ``#RRGGBB`` hex color,
        followed by a black foreground for readability on the colored band.

        Raises:
            ValueError: if the code is not exactly 6 hex digits.
        """
        hex_code = hex_code.lstrip("#")
        if len(hex_code) != 6:
            raise ValueError("Hex code must be 6 characters long (e.g., 'RRGGBB').")
        r, g, b = (int(hex_code[i:i + 2], 16) for i in (0, 2, 4))
        return f"\033[48;2;{r};{g};{b}m{self.use_hex('000000')}"
|
|
|
# Module-wide color helper instance used for styled terminal output.
colors = custom_colors()
|
|
|
def help_command(command):
    """Pretty-print one help line with colored highlighting.

    The first word (the command/flag name) is printed in blue, any
    ``<placeholder>`` tokens in red, and everything else unstyled.

    Args:
        command: a help string such as ``"-m <model> - description"``.
    """
    words = command.split(sep=" ")
    # Build a new styled list instead of mutating by index (idiomatic).
    styled = [f"{colors.use_hex('#3498db')}{words[0]}{colors.esc}"]
    for word in words[1:]:
        if word.startswith("<"):
            word = f"{colors.use_hex('#e74c3c')}{word}{colors.esc}"
        styled.append(word)
    print(" ".join(styled))
|
|
|
# --- Runtime configuration (defaults, overridable via CLI flags) ---
total = []  # accumulates generated chunks in the interactive loop below
model = "stok-0.3.json"  # model file handed to load_model(); set with -m
help_ran = False  # set when "help" was requested; suppresses the REPL
show_speed = False  # -speed: print tokens/sec stats after each reply
server_mode = False  # -server: run the Flask API instead of the REPL
server_port = "8008"  # --port; stored as str — NOTE(review): confirm Flask accepts a string port
server_host = "127.0.0.1"  # --host

# "help" must be the *first* argument; prints colored usage lines.
if len(sys.argv) > 1:
    if sys.argv[1] == "help":
        help_ran = True
        help_command("help - shows this command")
        help_command("-m <model> - specifies the file you want to inference")
        help_command("-speed - if added, enables speed logging")
        help_command("-server - if added, runs as flask server with OpenAI compatibility")
        help_command("--host <host_name> - something like 0.0.0.0 to change where you host")
        help_command("--port <port> - changes port number for server mode")
        print()

# Destructive flag parsing over a copy of argv: repeatedly consume
# args[1] (and its value, for flags that take one) until nothing
# recognizable remains.  NOTE(review): a value-taking flag given as the
# last argument (e.g. a trailing "-m") raises IndexError on args[2].
args = list(sys.argv)
running = True
while running:
    if len(args) < 2:
        running = False
    elif args[1] == "-m":
        model = args[2]
        args.pop(1)
        args.pop(1)
    elif args[1] == "-speed":
        show_speed = True
        args.pop(1)
    elif args[1] == "-server":
        server_mode = True
        args.pop(1)
    elif args[1] == "--port":
        server_port = args[2]
        args.pop(1)
        args.pop(1)
    elif args[1] == "--host":
        server_host = args[2]
        args.pop(1)
        args.pop(1)
    else:
        # First unrecognized token stops parsing (remaining args ignored).
        running = False
|
|
|
|
|
# --- Interactive REPL mode (default when neither "help" nor -server) ---
if not help_ran and not server_mode:
    start = time.time()
    load_model(model)
    end = time.time()
    print(f"took {end-start}s to load file")
    running = True
    while running:
        total = []  # chunks of the current reply, for the speed report
        message = input(">>>")
        if message == "/quit" or message == "/exit" or message == "/bye":
            running = False
        else:
            # run_model yields tokens; stream them to the terminal as
            # they arrive and time the whole generation.
            chunks = run_model(message, max_tokens=100, repetition_penalty=2)
            start = time.time()
            for chunk in chunks:
                total.append(chunk)
                print(chunk, end="")
            end = time.time()
            print()
            if show_speed:
                print(f"Took: {end-start}s")
                print(f"Generated: {len(total)}")
                print(f"Speed: {len(total)/(end-start)} t/s")
            print("_____________________________")
|
|
|
# --- Flask server mode: minimal OpenAI-compatible chat completions API ---
if server_mode:
    print("loading model...")
    start = time.time()
    load_model(model)
    end = time.time()
    print(f"took {end-start}s to load file")
    app = Flask(__name__)

    @app.route("/v1/chat/completions", methods=["POST"])
    def chat_completions():
        """Handle an OpenAI-style chat completion request.

        Request JSON fields used: ``messages``, ``prompt`` (legacy,
        appended as a user message), ``max_tokens``, ``repetition_penalty``,
        ``temperature`` (accepted but not used by run_model), ``stream``.

        Returns either a Server-Sent-Events stream of ``chat.completion.chunk``
        objects or a single ``chat.completion`` JSON body.
        """
        # NOTE(review): the Authorization header is not validated — any
        # (or no) API key is accepted.
        data = request.json
        messages = data.get('messages', [])
        prompt = data.get("prompt", None)
        if prompt is not None:  # fixed: identity comparison for None
            print("prompt used")
            # Legacy completions-style "prompt" becomes one user message.
            messages.append({"role": "user", "content": prompt})

        temperature = data.get('temperature', 0.7)  # accepted, currently unused
        max_tokens = data.get('max_tokens', 200)
        repetition_penalty = data.get("repetition_penalty", 2)
        stream = data.get('stream', False)

        try:
            if stream:
                def generate(max_tokens, temperature):
                    """Yield OpenAI-style SSE chunks, then usage + [DONE]."""
                    try:
                        message = messages[-1]["content"]
                        generated_chunks = 0
                        completion = run_model(message, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
                        for chunk in completion:
                            generated_chunks += 1
                            delta_chunk = {
                                "id": f"chatcmpl-{datetime.now().timestamp()}",
                                "object": "chat.completion.chunk",
                                "created": int(datetime.now().timestamp()),
                                "model": model,
                                "choices": [{
                                    "index": 0,
                                    "delta": {
                                        "content": chunk,
                                    },
                                    "finish_reason": None
                                }]
                            }
                            yield f"data: {json.dumps(delta_chunk)}\n\n"
                        # Terminal chunk: empty delta, stop reason, usage.
                        # Token counts approximate: output = chunk count,
                        # input = whitespace-split word count of the prompt.
                        output_tokens = generated_chunks
                        input_tokens = len(message.split(sep=None))
                        final_chunk = {
                            "id": f"chatcmpl-{datetime.now().timestamp()}",
                            "object": "chat.completion.chunk",
                            "created": int(datetime.now().timestamp()),
                            "model": model,
                            "choices": [{
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop"
                            }],
                            "usage": {
                                "prompt_tokens": input_tokens,
                                "completion_tokens": output_tokens,
                                "total_tokens": input_tokens + output_tokens,
                            },
                        }
                        yield f"data: {json.dumps(final_chunk)}\n\n"
                        yield "data: [DONE]\n\n"
                    except RuntimeError as e:
                        # BUG FIX: exceptions are not JSON-serializable —
                        # json.dumps(e) itself raised TypeError; use str(e).
                        yield f"event: error\ndata: {json.dumps({'error': {'message': str(e), 'type': 'internal_error', 'code': 500}})}\n\n"

                try:
                    return Response(
                        stream_with_context(generate(max_tokens, temperature)),
                        content_type='text/event-stream'
                    )
                except Exception as e:
                    return jsonify({
                        # BUG FIX: str(e) — exception objects can't be jsonify'd.
                        "error": {
                            "message": str(e),
                            "type": "internal_error",
                            "code": 500
                        }
                    }), 500

            else:
                # Non-streaming: collect the full completion, reply once.
                final_content = []
                message = messages[-1]["content"]
                output_tokens = 0
                # BUG FIX: honour the request's repetition_penalty instead
                # of the previously hard-coded 2 (default is still 2).
                completion = run_model(message, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
                for chunk in completion:
                    output_tokens += 1
                    final_content.append(chunk)
                final_content = "".join(final_content)
                input_tokens = len(message.split(sep=None))
                response = {
                    "id": f"chatcmpl-{datetime.now().timestamp()}",
                    "object": "chat.completion",
                    "created": int(datetime.now().timestamp()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": final_content
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": input_tokens,
                        "completion_tokens": output_tokens,
                        "total_tokens": output_tokens + input_tokens
                    }
                }
                return response

        except Exception as e:
            return jsonify({
                # BUG FIX: str(e) — exception objects can't be jsonify'd.
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": 500
                }
            }), 500

    # Port arrives as a string from argv; cast explicitly for robustness.
    app.run(host=server_host, port=int(server_port))
|
|
|
|
|