# Apex-1.5-Instruct-350M — inference.py
# (Hugging Face repo header: LH-Tech-AI, "Create inference.py", commit 22ac8a5 verified)
import os
import sys
import site
# Make pip-installed NVIDIA cuDNN libraries visible to the dynamic loader.
# LD_LIBRARY_PATH is only read by the loader at process start, so after
# prepending the path we re-exec the interpreter once; the RESTARTED env
# var guards against an infinite exec loop.
# NOTE(review): nesting reconstructed from a flattened paste — the
# RESTARTED re-exec is assumed to sit inside the path-exists branch
# (only restart when the env was actually changed); confirm upstream.
try:
    cudnn_path = os.path.join(site.getsitepackages()[0], 'nvidia', 'cudnn', 'lib')
    if os.path.exists(cudnn_path):
        if 'LD_LIBRARY_PATH' in os.environ:
            os.environ['LD_LIBRARY_PATH'] = f"{cudnn_path}:{os.environ['LD_LIBRARY_PATH']}"
        else:
            os.environ['LD_LIBRARY_PATH'] = cudnn_path
        if "RESTARTED" not in os.environ:
            os.environ["RESTARTED"] = "1"
            # Replace the current process so the loader re-reads the env.
            os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception:
    # Deliberate best-effort: GPU library setup failures must not block
    # the script — CPU fallback (or a system-wide cuDNN) may still work.
    pass
import onnxruntime as ort
import tiktoken
import numpy as np
import time
# --- Configuration ---
MODEL_PATH = "Apex_1.5_DYNAMIC.onnx"
VOCAB_SIZE = 50304  # logits are sliced to this width; GPT-2 BPE ids are < 50257

# GPT-2 byte-pair encoder used to tokenize prompts and decode output.
enc = tiktoken.get_encoding("gpt2")

# Setup ONNX session: prefer CUDA, fall back to CPU automatically.
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

print(f"🚀 Loading Dynamic ONNX Model: {MODEL_PATH}...")

providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
    }),
    'CPUExecutionProvider',
]

try:
    session = ort.InferenceSession(MODEL_PATH, sess_options=options, providers=providers)
    print(f"✅ Active Provider: {session.get_providers()[0]}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Fix: bare sys.exit() exits with status 0 (success); report failure.
    sys.exit(1)
def get_param(prompt, default):
    """Prompt the user for a value, falling back to *default*.

    The reply is coerced to ``type(default)`` (float for 0.55, int for
    500, ...).  An empty reply — or a reply that cannot be coerced —
    returns *default* instead of crashing the chat loop with ValueError.
    """
    val = input(f"{prompt} (Default: {default}): ").strip()
    if not val:
        return default
    try:
        return type(default)(val)
    except ValueError:
        print(f" Invalid value {val!r}, using default {default}.")
        return default
def apply_sampling(logits, temperature, top_k, repetition_penalty, history):
    """Sample the next token id from raw *logits*.

    Applies, in order: repetition penalty over tokens already seen in
    *history*, temperature scaling, top-k filtering, then a random draw
    from the resulting softmax distribution.

    Args:
        logits: 1-D float array of raw scores (penalty is applied in place).
        temperature: softmax temperature; clamped to >= 1e-6.
        top_k: number of highest-scoring tokens kept; clamped to [1, len(logits)].
        repetition_penalty: values > 1.0 discourage repeats (1.0 disables).
        history: iterable of previously generated token ids.

    Returns:
        int: the sampled token id.
    """
    # 1. Repetition penalty, vectorized (the per-token Python loop was
    # O(unique tokens) interpreter overhead per generated token):
    # divide positive logits, multiply negative ones, so repeats lose score.
    if repetition_penalty != 1.0 and len(history) > 0:
        seen = np.unique(history)
        seen = seen[seen < len(logits)]
        logits[seen] = np.where(logits[seen] > 0,
                                logits[seen] / repetition_penalty,
                                logits[seen] * repetition_penalty)
    # 2. Temperature scaling (guard against division by zero).
    logits = logits / max(temperature, 1e-6)
    # 3. Top-K filtering: everything below the k-th largest logit -> -inf.
    # Clamp to >= 1 so top_k=0 cannot produce an invalid partition index.
    top_k = max(1, min(top_k, len(logits)))
    kth_value = np.partition(logits, -top_k)[-top_k]
    logits[logits < kth_value] = -float('Inf')
    # 4. Numerically stable softmax, then a random draw.
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)
    return int(np.random.choice(len(logits), p=probs))
def run_chat():
    """Interactive chat REPL: read a prompt, stream a sampled completion.

    Uses the module-level ``session`` (ONNX) and ``enc`` (tiktoken)
    globals; loops until the user types exit/quit/beenden.
    """
    print("\n" + "=" * 50)
    print(" APEX 1.5 DYNAMIC ONNX INTERACTIVE CHAT")
    print("=" * 50 + "\n")
    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit", "beenden"]:
            break
        # Per-prompt sampling parameters.
        temp = get_param(" Temperature", 0.55)
        tk = get_param(" Top-K", 40)
        rp = get_param(" Repetition Penalty", 1.2)
        max_tk = get_param(" Max New Tokens", 500)
        # Instruction-tuned prompt template.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        input_ids = enc.encode(prompt)
        history = list(input_ids)  # penalty history includes the prompt tokens
        print("\nApex 1.5: ", end="", flush=True)
        start_time = time.time()
        token_count = 0
        last_printed_len = 0
        full_response_ids = []
        # Autoregressive generation loop.
        for _ in range(max_tk):
            # Dynamic input shape (1, seq_len); keep at most the last
            # 1024 tokens of context.
            current_ctx = input_ids[-1024:]
            input_array = np.array([current_ctx], dtype=np.int64)
            outputs = session.run(None, {'input': input_array})
            # Logits at the last position: [batch, seq, vocab] sliced to VOCAB_SIZE.
            logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)
            next_token = apply_sampling(logits, temp, tk, rp, history)
            # Stop on end-of-text or on padding ids beyond the GPT-2 vocab.
            if next_token == enc.eot_token or next_token >= 50257:
                break
            input_ids.append(next_token)
            full_response_ids.append(next_token)
            history.append(next_token)
            token_count += 1
            # Re-decode the whole response and emit only the new suffix,
            # so multi-byte BPE pieces render correctly.
            decoded_text = enc.decode(full_response_ids)
            new_text = decoded_text[last_printed_len:]
            # Stop if the model starts a new "Instruction:" turn.
            # Fix: flush the answer text that precedes the marker instead
            # of silently dropping the final chunk.
            if "Instruction:" in new_text:
                print(new_text.split("Instruction:", 1)[0], end="", flush=True)
                break
            print(new_text, end="", flush=True)
            last_printed_len = len(decoded_text)
        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0
        print(f"\n\n[Speed: {tps:.2f} tokens/s | Time: {duration:.2f}s]")
        print("-" * 40 + "\n")


if __name__ == "__main__":
    run_chat()