| | import os |
| | import sys |
| | import site |
| |
|
# Best-effort shim: make pip-installed cuDNN libraries visible to
# onnxruntime's CUDA provider.  The dynamic linker reads LD_LIBRARY_PATH
# only at process start, so after prepending the path we re-exec the
# interpreter once (the RESTARTED env var guards against a restart loop).
try:
    cudnn_path = os.path.join(site.getsitepackages()[0], 'nvidia', 'cudnn', 'lib')
    if os.path.exists(cudnn_path):
        current = os.environ.get('LD_LIBRARY_PATH')
        # Skip the env update and the re-exec entirely when the path is
        # already present (e.g. set by the shell) -- avoids a needless restart.
        if not current or cudnn_path not in current.split(':'):
            os.environ['LD_LIBRARY_PATH'] = (
                f"{cudnn_path}:{current}" if current else cudnn_path
            )
            if "RESTARTED" not in os.environ:
                os.environ["RESTARTED"] = "1"
                os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception:
    # Deliberate best-effort: on any failure fall through and let
    # onnxruntime resolve CUDA (or fall back to CPU) on its own.
    pass
| |
|
| | import onnxruntime as ort |
| |
|
| | import tiktoken |
| | import numpy as np |
| | import time |
| |
|
| | |
# Path to the exported ONNX model with a dynamic sequence-length axis.
MODEL_PATH = "Apex_1.5_DYNAMIC.onnx"
# Width of the model's logit dimension.  Padded beyond GPT-2's 50257 real
# tokens; padded ids are never emitted (see the >= 50257 guard in run_chat).
VOCAB_SIZE = 50304
# GPT-2 BPE tokenizer, used both to encode prompts and decode completions.
enc = tiktoken.get_encoding("gpt2")


# Enable onnxruntime's full graph-optimization pipeline.
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL


print(f"🚀 Loading Dynamic ONNX Model: {MODEL_PATH}...")
# Provider priority: CUDA first, transparent fallback to CPU.
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
    }),
    'CPUExecutionProvider'
]
| |
|
try:
    session = ort.InferenceSession(MODEL_PATH, sess_options=options, providers=providers)
    print(f"✅ Active Provider: {session.get_providers()[0]}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Exit with a non-zero status so shells/scripts see the failure;
    # bare sys.exit() would report success (exit code 0).
    sys.exit(1)
| |
|
def get_param(prompt, default):
    """Prompt the user for a value, coercing it to the type of *default*.

    An empty reply returns *default*.  A reply that cannot be parsed
    (e.g. "abc" for a float) also returns *default* instead of raising
    ValueError and crashing the chat loop.
    """
    val = input(f"{prompt} (Default: {default}): ").strip()
    if not val:
        return default
    try:
        return type(default)(val)
    except ValueError:
        print(f"  Invalid value {val!r}, using default {default}.")
        return default
| |
|
def apply_sampling(logits, temperature, top_k, repetition_penalty, history):
    """Sample one token id after repetition penalty, temperature and top-k.

    Args:
        logits: 1-D float array of raw model scores.  Not mutated -- a
            float64 working copy is taken (float64 also protects
            np.random.choice from float32 "probabilities do not sum to 1"
            rounding errors).
        temperature: softmax temperature, clamped to >= 1e-6.
        top_k: number of highest-scoring tokens kept for sampling.
        repetition_penalty: CTRL-style penalty; values > 1.0 discourage
            tokens already present in *history*.
        history: iterable of previously seen token ids.

    Returns:
        int: the sampled token id.
    """
    logits = np.array(logits, dtype=np.float64, copy=True)

    # CTRL-style penalty: divide positive scores, multiply negative ones,
    # so a penalized token always becomes less likely.
    if repetition_penalty != 1.0 and len(history) > 0:
        seen = np.unique(np.asarray(history))
        seen = seen[seen < len(logits)]  # ignore out-of-vocab ids
        positive = logits[seen] > 0
        logits[seen[positive]] /= repetition_penalty
        logits[seen[~positive]] *= repetition_penalty

    # Temperature scaling (clamped to avoid division by zero).
    logits /= max(temperature, 1e-6)

    # Top-k filtering: mask everything below the k-th largest score.
    top_k = min(top_k, len(logits))
    kth_value = np.partition(logits, -top_k)[-top_k]
    logits[logits < kth_value] = -np.inf

    # Numerically stable softmax over the surviving logits.
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)

    return int(np.random.choice(len(logits), p=probs))
| |
|
def run_chat():
    """Interactive REPL: read a prompt, stream a sampled completion.

    Relies on the module-level `session` (onnxruntime InferenceSession),
    `enc` (GPT-2 tokenizer) and `VOCAB_SIZE`.  Type "exit", "quit" or
    "beenden" to leave the loop.
    """
    print("\n" + "="*50)
    print("   APEX 1.5 DYNAMIC ONNX INTERACTIVE CHAT")
    print("="*50 + "\n")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit", "beenden"]:
            break

        # Per-turn sampling parameters (empty input keeps the default).
        temp = get_param("  Temperature", 0.55)
        tk = get_param("  Top-K", 40)
        rp = get_param("  Repetition Penalty", 1.2)
        max_tk = get_param("  Max New Tokens", 500)

        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        input_ids = enc.encode(prompt)
        history = list(input_ids)

        print("\nApex 1.5: ", end="", flush=True)

        start_time = time.time()
        token_count = 0
        last_printed_len = 0
        full_response_ids = []

        for _ in range(max_tk):
            # Sliding window: cap the model context at 1024 tokens.
            current_ctx = input_ids[-1024:]
            input_array = np.array([current_ctx], dtype=np.int64)

            outputs = session.run(None, {'input': input_array})
            # Last-position logits, clipped to the real vocab width in
            # case the export padded the logit dimension.
            logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

            next_token = apply_sampling(logits, temp, tk, rp, history)

            # Stop on end-of-text or any padded id outside GPT-2's vocab.
            if next_token == enc.eot_token or next_token >= 50257:
                break

            input_ids.append(next_token)
            full_response_ids.append(next_token)
            history.append(next_token)
            token_count += 1

            # Re-decode the whole response each step so multi-byte BPE
            # tokens are only printed once they form valid text.
            decoded_text = enc.decode(full_response_ids)
            new_text = decoded_text[last_printed_len:]

            # The model sometimes starts hallucinating a new turn; print
            # the valid text preceding the marker (previously dropped),
            # then stop generating.
            if "Instruction:" in new_text:
                print(new_text.split("Instruction:", 1)[0], end="", flush=True)
                break

            print(new_text, end="", flush=True)
            last_printed_len = len(decoded_text)

        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0

        print(f"\n\n[Speed: {tps:.2f} tokens/s | Time: {duration:.2f}s]")
        print("-" * 40 + "\n")
| |
|
# Script entry point: start the interactive chat only when run directly,
# not when imported as a module.
if __name__ == "__main__":
    run_chat()