File size: 5,003 Bytes
22ac8a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
import site

# Best-effort bootstrap: make pip-installed NVIDIA cuDNN libraries visible to
# the dynamic loader. LD_LIBRARY_PATH is only read by the loader at process
# start, so after prepending the path we re-exec the interpreter once; the
# RESTARTED env var guards against an infinite exec loop.
try:
    cudnn_path = os.path.join(site.getsitepackages()[0], 'nvidia', 'cudnn', 'lib')
    if os.path.exists(cudnn_path):
        if 'LD_LIBRARY_PATH' in os.environ:
            # Prepend so the bundled cuDNN wins over any system copy.
            os.environ['LD_LIBRARY_PATH'] = f"{cudnn_path}:{os.environ['LD_LIBRARY_PATH']}"
        else:
            os.environ['LD_LIBRARY_PATH'] = cudnn_path
        if "RESTARTED" not in os.environ:
            os.environ["RESTARTED"] = "1"
            # execv replaces this process; the script starts over from the top
            # with the updated environment.
            os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception:
    # Deliberate best-effort: if the site-packages layout differs or exec
    # fails, fall through and let onnxruntime try to find cuDNN on its own.
    pass

import onnxruntime as ort

import tiktoken
import numpy as np
import time

# --- Configuration ---
# Path to the exported ONNX model (dynamic sequence-length axis).
MODEL_PATH = "Apex_1.5_DYNAMIC.onnx"
# Size of the model's output head. Larger than GPT-2's 50257 real tokens;
# the >= 50257 stop check in run_chat treats the extra ids as invalid —
# presumably vocab padding in the exported head (confirm against the model).
VOCAB_SIZE = 50304 
# GPT-2 BPE tokenizer via tiktoken.
enc = tiktoken.get_encoding("gpt2")

# --- ONNX Runtime session setup (CUDA preferred, CPU fallback) ---
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

print(f"🚀 Loading Dynamic ONNX Model: {MODEL_PATH}...")
# Providers are tried in order; onnxruntime falls back to CPU when CUDA
# is unavailable.
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
    }),
    'CPUExecutionProvider'
]

try:
    session = ort.InferenceSession(MODEL_PATH, sess_options=options, providers=providers)
    # get_providers() returns the active providers in priority order.
    print(f"✅ Active Provider: {session.get_providers()[0]}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Fix: bare sys.exit() exits with status 0 (success) even though loading
    # failed; report failure to the shell with a non-zero status.
    sys.exit(1)

def get_param(prompt, default):
    """Prompt the user for a value, falling back to *default*.

    The reply is coerced to the type of *default* (int or float here).
    An empty reply — or one that cannot be converted — yields the default,
    so a typo no longer aborts the whole chat session with a ValueError.
    """
    val = input(f"{prompt} (Default: {default}): ").strip()
    if not val:
        return default
    try:
        return type(default)(val)
    except ValueError:
        # Fix: the original let this propagate, crashing the chat loop.
        print(f"  Invalid value {val!r}, using default {default}.")
        return default

def apply_sampling(logits, temperature, top_k, repetition_penalty, history):
    """Sample the next token id from raw *logits*.

    Applies, in order: repetition penalty over the tokens in *history*,
    temperature scaling, top-k filtering, and a softmax draw.

    Args:
        logits: 1-D array of vocabulary scores. Not modified (the original
            implementation mutated it in place).
        temperature: softmax temperature; clamped to >= 1e-6 to avoid
            division by zero.
        top_k: number of highest-scoring tokens kept; clamped to [1, vocab].
        repetition_penalty: values > 1.0 discourage tokens already in history.
        history: iterable of previously seen token ids (prompt + generated).

    Returns:
        int: the sampled token index.
    """
    # Fix: work on a copy so the caller's array is never mutated.
    logits = np.array(logits, dtype=np.float64, copy=True)

    # 1. Repetition penalty, vectorized: divide positive logits, multiply
    #    negative ones, for every distinct in-vocab token in history.
    if repetition_penalty != 1.0 and len(history) > 0:
        seen = np.unique(np.asarray(history))
        seen = seen[seen < len(logits)]  # ignore ids outside the vocab
        vals = logits[seen]
        logits[seen] = np.where(vals > 0,
                                vals / repetition_penalty,
                                vals * repetition_penalty)

    # 2. Temperature scaling.
    logits = logits / max(temperature, 1e-6)

    # 3. Top-K filtering: everything below the k-th largest logit is masked
    #    to -inf so it receives zero probability. Fix: clamp top_k to >= 1
    #    (top_k == 0 silently disabled filtering in the original).
    top_k = max(1, min(top_k, len(logits)))
    threshold = np.partition(logits, -top_k)[-top_k]
    logits[logits < threshold] = -np.inf

    # 4. Numerically stable softmax, then a categorical draw.
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)
    return int(np.random.choice(len(logits), p=probs))

def run_chat():
    """Interactive REPL: read a prompt, stream-generate a reply, repeat.

    Reads per-turn sampling parameters from stdin, runs one ONNX forward
    pass per generated token, and prints the decoded text incrementally.
    Exits when the user types exit/quit/beenden.
    """
    print("\n" + "="*50)
    print("   APEX 1.5 DYNAMIC ONNX INTERACTIVE CHAT")
    print("="*50 + "\n")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit", "beenden"]:
            break

        # Per-turn sampling parameters (empty input keeps the default).
        temp = get_param("  Temperature", 0.55)
        tk = get_param("  Top-K", 40)
        rp = get_param("  Repetition Penalty", 1.2)
        max_tk = get_param("  Max New Tokens", 500)

        # Tokenize with the instruction template the model was tuned on.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        input_ids = enc.encode(prompt)

        print("\nApex 1.5: ", end="", flush=True)

        start_time = time.time()
        token_count = 0
        last_printed_len = 0
        full_response_ids = []

        # Autoregressive loop: one ONNX forward pass per generated token.
        for _ in range(max_tk):
            # Dynamic input shape (1, seq_len); keep only the last 1024
            # tokens of context if the conversation grows longer.
            current_ctx = input_ids[-1024:]
            input_array = np.array([current_ctx], dtype=np.int64)

            outputs = session.run(None, {'input': input_array})

            # Output is [batch, seq, vocab]; take the last position's logits.
            logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

            # Fix: the original kept a separate `history` list that was a
            # byte-for-byte duplicate of input_ids (same init, same appends);
            # input_ids itself is the repetition-penalty history.
            next_token = apply_sampling(logits, temp, tk, rp, input_ids)

            # Stop on end-of-text, or on any id beyond GPT-2's 50257 real
            # tokens (ids up to VOCAB_SIZE are presumably head padding).
            if next_token == enc.eot_token or next_token >= 50257:
                break

            input_ids.append(next_token)
            full_response_ids.append(next_token)
            token_count += 1

            # Incremental decode: re-decode the response and print only the
            # suffix added since the last print.
            decoded_text = enc.decode(full_response_ids)
            new_text = decoded_text[last_printed_len:]

            # Heuristic stop: the model began a new instruction block.
            if "Instruction:" in new_text:
                break

            print(new_text, end="", flush=True)
            last_printed_len = len(decoded_text)

        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0

        print(f"\n\n[Speed: {tps:.2f} tokens/s | Time: {duration:.2f}s]")
        print("-" * 40 + "\n")

# Entry point: start the interactive chat loop only when run as a script.
if __name__ == "__main__":
    run_chat()