# stokfile.py — Stok model CLI and OpenAI-compatible Flask server.
import run_stok
import json
from flask import Flask, jsonify, Response, stream_with_context, request
from datetime import datetime
import sys
from run_stok import load_model, run_model
import time
class custom_colors:
    """ANSI truecolor escape-code helpers for terminal output."""

    def use_hex(self, hex_code: str) -> str:
        """Return the ANSI foreground escape sequence for a hex color.

        Args:
            hex_code: 6-digit hex color, with or without a leading "#".

        Raises:
            ValueError: if the code is not exactly 6 hex digits
                (same validation as use_hex_bg, for consistency).
        """
        hex_code = hex_code.lstrip("#")  # Remove the "#" symbol if present
        if len(hex_code) != 6:
            raise ValueError("Hex code must be 6 characters long (e.g., 'RRGGBB').")
        rgb = tuple(int(hex_code[i:i+2], 16) for i in (0, 2, 4))  # Convert hex to RGB
        return f"\033[38;2;{rgb[0]};{rgb[1]};{rgb[2]}m"

    def use_hex_bg(self, hex_code: str) -> str:
        """Return the ANSI background escape for a hex color, with a black foreground."""
        hex_code = hex_code.lstrip("#")
        if len(hex_code) != 6:
            raise ValueError("Hex code must be 6 characters long (e.g., 'RRGGBB').")
        rgb = tuple(int(hex_code[i:i+2], 16) for i in (0, 2, 4))
        return f"\033[48;2;{rgb[0]};{rgb[1]};{rgb[2]}m{self.use_hex('000000')}"

    def __init__(self):
        # Preset escape sequences used throughout the CLI output.
        self.cream = "\033[38;2;255;245;151m"
        self.beige = "\033[38;2;245;245;220m"
        self.green = "\033[38;2;52;199;89m"
        self.red = self.use_hex("#d53410")
        self.esc = "\033[0m"  # reset attributes
# Shared palette instance used by help_command and the CLI below.
colors = custom_colors()
def help_command(command):
    """Print one line of usage help, colorizing the command name and <placeholders>."""
    words = command.split(sep=" ")
    name_color = colors.use_hex('#3498db')
    arg_color = colors.use_hex('#e74c3c')
    styled = [f"{name_color}{words[0]}{colors.esc}"]
    for word in words[1:]:
        if word.startswith("<"):
            word = f"{arg_color}{word}{colors.esc}"
        styled.append(word)
    print(" ".join(styled))
# ---- CLI defaults ------------------------------------------------------
total = []                    # chunks collected for the current reply
model = "stok-0.3.json"       # model file handed to load_model()
help_ran = False              # True once help was shown; skips inference
show_speed = False            # -speed: log tokens/sec after each reply
server_mode = False           # -server: run the Flask OpenAI-compatible API
server_port = "8008"          # --port value
server_host = "127.0.0.1"     # --host value

if len(sys.argv) > 1:  # it is set up like this to add more parameters in the future
    if sys.argv[1] == "help":
        help_ran = True
        for usage in (
            "help - shows this command",
            "-m <model> - specifies the file you want to inference",
            "-speed - if added, enables speed logging",
            "-server - if added, runs as flask server with OpenAI compatibility",
            "--host <host_name> - something like 0.0.0.0 to change where you host",
            "--port <port> - changes port number for server mode",
        ):
            help_command(usage)
        print()
args = list(sys.argv)  # mutable copy consumed by the flag-parsing loop
# Consume recognized flags from the front of args until none remain.
# Each branch pops what it consumed, so args[1] is always the next candidate.
running = True
while running:
    if len(args) < 2:
        running = False  # nothing left to parse
    elif args[1] in ("-m", "--port", "--host"):
        # Flags that require a value. The original indexed args[2]
        # unconditionally, raising IndexError when the flag was the
        # last token; report it and stop parsing instead.
        if len(args) < 3:
            print(f"{args[1]} requires a value")
            running = False
        else:
            value = args[2]
            if args[1] == "-m":
                model = value
            elif args[1] == "--port":
                server_port = value
            else:
                server_host = value
            del args[1:3]  # drop the flag and its value
    elif args[1] == "-speed":
        show_speed = True
        args.pop(1)
    elif args[1] == "-server":
        server_mode = True
        args.pop(1)
    else:
        running = False  # unrecognized token: leave it and stop
# Interactive REPL mode: runs when neither help nor server mode was requested.
if not help_ran and not server_mode:
    start = time.time()
    load_model(model)
    end = time.time()
    print(f"took {end-start}s to load file")
    running = True
    while running:
        total = []  # chunks of the current reply
        message = input(">>>")
        if message in ("/quit", "/exit", "/bye"):
            running = False
        else:
            chunks = run_model(message, max_tokens=100, repetition_penalty=2)
            start = time.time()  # time only the generation loop
            for chunk in chunks:
                total.append(chunk)
                print(chunk, end="")
            end = time.time()
            print()
            if show_speed:
                elapsed = end - start
                print(f"Took: {elapsed}s")
                print(f"Generated: {len(total)}")
                # Guard the division: a near-instant reply can make
                # elapsed == 0.0, which previously raised ZeroDivisionError.
                if elapsed > 0:
                    print(f"Speed: {len(total)/elapsed} t/s")
                else:
                    print("Speed: n/a (elapsed time too small to measure)")
                print("_____________________________")
# Server mode: expose the model through an OpenAI-compatible Flask endpoint.
if server_mode:
    print("loading model...")
    start = time.time()
    load_model(model)
    end = time.time()
    print(f"took {end-start}s to load file")

    app = Flask(__name__)

    @app.route("/v1/chat/completions", methods=["POST"])
    def chat_completions():
        """Handle POST /v1/chat/completions (streaming and non-streaming).

        Accepts an OpenAI-style body: `messages` (list of {role, content}),
        optional legacy `prompt`, plus `temperature`, `max_tokens`,
        `repetition_penalty` and `stream`. Only the last message's content
        is fed to the model.
        """
        # Read but not enforced: any (or no) API key is accepted.
        auth_header = request.headers.get('Authorization')
        data = request.json
        messages = data.get('messages', [])
        prompt = data.get("prompt", None)
        if prompt is not None:
            # Legacy completions-style field: fold it in as a user message.
            print("prompt used")
            messages.append({"role": "user", "content": prompt})
        if not messages:
            # Originally this fell through to an IndexError on messages[-1];
            # report a proper 400 instead.
            return jsonify({
                "error": {
                    "message": "Request must include 'messages' or 'prompt'.",
                    "type": "invalid_request_error",
                    "code": 400
                }
            }), 400
        temperature = data.get('temperature', 0.7)  # accepted but not used by run_model
        max_tokens = data.get('max_tokens', 200)
        repetition_penalty = data.get("repetition_penalty", 2)
        # tools not currently supported
        # tools = data.get("tools", None)
        stream = data.get('stream', False)
        # response_format not currently supported
        # response_format = data.get('response_format', None)
        try:
            if stream:
                def generate(max_tokens, temperature):
                    # SSE generator: one chunk event per token, then a final
                    # usage chunk and the [DONE] sentinel.
                    try:
                        message = messages[-1]["content"]
                        generated_chunks = 0
                        completion = run_model(message, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
                        for chunk in completion:
                            generated_chunks += 1
                            delta_chunk = {
                                "id": f"chatcmpl-{datetime.now().timestamp()}",
                                "object": "chat.completion.chunk",
                                "created": int(datetime.now().timestamp()),
                                "model": model,
                                "choices": [{
                                    "index": 0,
                                    "delta": {
                                        "content": chunk,
                                    },
                                    "finish_reason": None
                                }]
                            }
                            yield f"data: {json.dumps(delta_chunk)}\n\n"
                        output_tokens = generated_chunks
                        # Rough prompt-token count: whitespace-separated words.
                        input_tokens = len(message.split(sep=None))
                        final_chunk = {
                            "id": f"chatcmpl-{datetime.now().timestamp()}",
                            "object": "chat.completion.chunk",
                            "created": int(datetime.now().timestamp()),
                            "model": model,
                            "choices": [{
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop"
                            }],
                            "usage": {
                                "prompt_tokens": input_tokens,
                                "completion_tokens": output_tokens,
                                "total_tokens": input_tokens + output_tokens,
                            },
                        }
                        yield f"data: {json.dumps(final_chunk)}\n\n"
                        yield "data: [DONE]\n\n"
                    except RuntimeError as e:
                        # str(e): an exception object is not JSON serializable,
                        # so json.dumps would itself raise TypeError here.
                        yield f"event: error\ndata: {json.dumps({'error': {'message': str(e), 'type': 'internal_error', 'code': 500}})}\n\n"
                # Errors while constructing the Response are caught by the outer
                # handler; errors during generation are handled inside generate()
                # because the body is produced lazily after this function returns.
                return Response(
                    stream_with_context(generate(max_tokens, temperature)),
                    content_type='text/event-stream'
                )
            else:
                final_content = []
                message = messages[-1]["content"]
                output_tokens = 0
                # Honor the client's repetition_penalty (was hard-coded to 2).
                completion = run_model(message, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
                for chunk in completion:
                    output_tokens += 1
                    final_content.append(chunk)
                final_content = "".join(final_content)
                input_tokens = len(message.split(sep=None))
                response = {
                    "id": f"chatcmpl-{datetime.now().timestamp()}",
                    "object": "chat.completion",
                    "created": int(datetime.now().timestamp()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": final_content
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": input_tokens,
                        "completion_tokens": output_tokens,
                        "total_tokens": output_tokens + input_tokens
                    }
                }
                return response
        except Exception as e:
            # str(e) — jsonify cannot serialize the exception object itself.
            return jsonify({
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": 500
                }
            }), 500

    app.run(host=server_host, port=server_port)