# Apex 1.5 — interactive chat client for a dynamic-shape ONNX GPT-2 model.
import os
import sys
import site

# onnxruntime-gpu needs libcudnn on LD_LIBRARY_PATH; the pip nvidia-cudnn wheel
# ships it inside site-packages. The dynamic loader only reads LD_LIBRARY_PATH
# at process start, so after patching the variable we re-exec the interpreter
# once, guarded by the RESTARTED marker to avoid an infinite exec loop.
try:
    cudnn_path = os.path.join(site.getsitepackages()[0], 'nvidia', 'cudnn', 'lib')
    if os.path.exists(cudnn_path):
        if 'LD_LIBRARY_PATH' in os.environ:
            os.environ['LD_LIBRARY_PATH'] = f"{cudnn_path}:{os.environ['LD_LIBRARY_PATH']}"
        else:
            os.environ['LD_LIBRARY_PATH'] = cudnn_path
        # Re-exec only once so the loader picks up the new search path.
        if "RESTARTED" not in os.environ:
            os.environ["RESTARTED"] = "1"
            os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception:
    # Best-effort: if the site-packages layout differs (no nvidia wheel,
    # virtualenv quirks), continue without the patch and let ORT fall back.
    pass
import onnxruntime as ort
import tiktoken
import numpy as np
import time

# --- Configuration ---
MODEL_PATH = "Apex_1.5_DYNAMIC.onnx"
VOCAB_SIZE = 50304  # padded vocab size; real GPT-2 token ids end at 50256

enc = tiktoken.get_encoding("gpt2")

# Setup ONNX session with CUDA, falling back to CPU if CUDA is unavailable.
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

print(f"🚀 Loading Dynamic ONNX Model: {MODEL_PATH}...")
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
    }),
    'CPUExecutionProvider'
]
try:
    session = ort.InferenceSession(MODEL_PATH, sess_options=options, providers=providers)
    print(f"✅ Active Provider: {session.get_providers()[0]}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Exit with a non-zero status so callers/scripts can detect the failure
    # (the original `sys.exit()` reported success with code 0).
    sys.exit(1)
def get_param(prompt, default):
    """Prompt for a value; fall back to *default* on empty or unparseable input.

    The entered text is coerced to ``type(default)`` (e.g. int or float), so the
    caller always gets back the same type it passed in.
    """
    raw = input(f"{prompt} (Default: {default}): ").strip()
    if not raw:
        return default
    try:
        return type(default)(raw)
    except ValueError:
        # Bad input (e.g. letters where a number is expected) must not crash
        # the interactive chat loop — keep the default instead.
        print(f" Invalid value, using default {default}.")
        return default
def apply_sampling(logits, temperature, top_k, repetition_penalty, history):
    """Sample the next token id from raw logits.

    Applies, in order: repetition penalty over tokens in *history*, temperature
    scaling, top-k filtering, then draws one token from the softmax
    distribution. Returns the sampled token id as a plain int.

    Note: ``logits`` is modified in place (callers pass a fresh float32 copy).
    """
    # 1. Repetition penalty: damp tokens already generated. Positive logits
    #    are divided, negative ones multiplied, so the penalty always lowers
    #    the score. Vectorized instead of a per-token Python loop.
    if repetition_penalty != 1.0 and len(history) > 0:
        seen = np.unique(history)
        seen = seen[seen < len(logits)]  # ignore ids outside this logit slice
        vals = logits[seen]
        logits[seen] = np.where(vals > 0, vals / repetition_penalty, vals * repetition_penalty)

    # 2. Temperature scaling (clamped to avoid division by zero).
    logits = logits / max(temperature, 1e-6)

    # 3. Top-K filtering: mask everything below the k-th largest logit.
    #    Clamp to [1, len]: top_k <= 0 would make np.partition misbehave
    #    (negative k raises; k == 0 silently disabled the filter).
    top_k = max(1, min(top_k, len(logits)))
    kth_value = np.partition(logits, -top_k)[-top_k]
    logits[logits < kth_value] = -float('Inf')

    # 4. Numerically stable softmax, then draw one token.
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)
    return int(np.random.choice(len(logits), p=probs))
def run_chat():
    """Interactive REPL: read a prompt, stream-generate a response token by token.

    Uses the module-level ``session`` (ONNX runtime), ``enc`` (tiktoken GPT-2
    encoder) and ``VOCAB_SIZE``. Loops until the user types exit/quit/beenden.
    """
    print("\n" + "=" * 50)
    print(" APEX 1.5 DYNAMIC ONNX INTERACTIVE CHAT")
    print("=" * 50 + "\n")
    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit", "beenden"]:
            break
        if not user_input.strip():
            continue  # nothing to generate for an empty prompt

        # Per-prompt sampling parameters
        temp = get_param(" Temperature", 0.55)
        tk = get_param(" Top-K", 40)
        rp = get_param(" Repetition Penalty", 1.2)
        max_tk = get_param(" Max New Tokens", 500)

        # Tokenize with the instruction/response template the model was tuned on.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        input_ids = enc.encode(prompt)
        history = list(input_ids)

        print("\nApex 1.5: ", end="", flush=True)
        start_time = time.time()
        token_count = 0
        last_printed_len = 0
        full_response_ids = []

        # Generation loop
        for _ in range(max_tk):
            # Dynamic input shape (1, seq_len); keep at most the last 1024
            # tokens so the sequence never exceeds the model's context window.
            current_ctx = input_ids[-1024:]
            input_array = np.array([current_ctx], dtype=np.int64)

            # Run ONNX inference; the model's single input is named 'input'.
            outputs = session.run(None, {'input': input_array})

            # Logits come back as [batch, seq, vocab]; take the last position
            # only, sliced to VOCAB_SIZE (the model pads the vocab to 50304).
            logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

            next_token = apply_sampling(logits, temp, tk, rp, history)

            # Stop on end-of-text or any id beyond GPT-2's real vocabulary
            # (ids >= 50257 are padding rows that cannot be decoded).
            if next_token == enc.eot_token or next_token >= 50257:
                break

            # Update state
            input_ids.append(next_token)
            full_response_ids.append(next_token)
            history.append(next_token)
            token_count += 1

            # Decode the whole response and print only the new suffix —
            # decoding token-by-token would split multi-byte characters.
            decoded_text = enc.decode(full_response_ids)
            new_text = decoded_text[last_printed_len:]

            # Stop if the model starts hallucinating a new instruction block.
            if "Instruction:" in new_text:
                break
            print(new_text, end="", flush=True)
            last_printed_len = len(decoded_text)

        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0
        print(f"\n\n[Speed: {tps:.2f} tokens/s | Time: {duration:.2f}s]")
        print("-" * 40 + "\n")


if __name__ == "__main__":
    run_chat()