File size: 10,250 Bytes
376fafa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import numpy as np
import onnxruntime as ort
from typing import List, Dict, Any, Iterator, Optional, Tuple
import re
import time

def load_onnx_model(model_path: str) -> Tuple[Any, ort.InferenceSession]:
    """
    Load an ONNX model for text generation
    
    Args:
        model_path: Path to the ONNX model file
        
    Returns:
        Tuple of (model_info, session)
    """
    try:
        # Configure ONNX runtime session options
        session_options = ort.SessionOptions()
        
        # Enable optimizations
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        
        # Set inter_op and intra_op threads for better performance
        session_options.inter_op_num_threads = 4
        session_options.intra_op_num_threads = 4
        
        # Create inference session
        session = ort.InferenceSession(model_path, session_options)
        
        # Get model info
        model_info = {
            "input_names": [input.name for input in session.get_inputs()],
            "output_names": [output.name for output in session.get_outputs()],
            "input_shapes": [input.shape for input in session.get_inputs()],
            "metadata": session.get_modelmeta() if hasattr(session, 'get_modelmeta') else {}
        }
        
        return model_info, session
        
    except Exception as e:
        raise Exception(f"Failed to load ONNX model from {model_path}: {str(e)}")

def preprocess_text(text: str) -> str:
    """
    Preprocess text for model input
    
    Args:
        text: Raw input text
        
    Returns:
        Preprocessed text
    """
    # Basic text cleaning
    text = text.strip()
    
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    
    return text

def postprocess_text(text: str) -> str:
    """
    Postprocess model output
    
    Args:
        text: Raw model output
        
    Returns:
        Cleaned and formatted text
    """
    if not text:
        return ""
    
    # Remove common artifacts
    text = text.strip()
    
    # Remove repeating whitespace
    text = re.sub(r'\s+', ' ', text)
    
    # Remove partial sentences at the end
    if text and not text.endswith(('.', '!', '?', '"', "'")):
        # Try to end at a reasonable punctuation
        sentences = re.split(r'[.!?]+', text)
        if len(sentences) > 1:
            text = '. '.join(sentences[:-1]) + '.'
    
    return text

def setup_chat_prompt(conversation_history: List[str], current_message: str) -> str:
    """
    Setup prompt for chat-based models
    
    Args:
        conversation_history: List of previous messages
        current_message: Current user message
        
    Returns:
        Formatted prompt for the model
    """
    prompt = ""
    
    # Add conversation history
    for i, msg in enumerate(conversation_history):
        if i % 2 == 0:
            prompt += f"Human: {msg}\n"
        else:
            prompt += f"Assistant: {msg}\n"
    
    # Add current message
    prompt += f"Human: {current_message}\nAssistant:"
    
    return prompt

def generate_response(
    session: ort.InferenceSession, 
    prompt: str, 
    max_length: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
) -> Iterator[str]:
    """
    Generate response using ONNX model with streaming
    
    Args:
        session: ONNX inference session
        prompt: Input prompt
        max_length: Maximum length of generated text
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        repetition_penalty: Repetition penalty
        
    Yields:
        Generated text chunks
    """
    try:
        # Tokenize input (this is a simplified version - you'd need proper tokenization)
        input_tokens = tokenize_text(prompt)
        
        # Convert to numpy arrays
        input_ids = np.array([input_tokens], dtype=np.int64)
        
        # Prepare attention mask (assuming all tokens are valid)
        attention_mask = np.ones_like(input_ids)
        
        # For this example, we'll simulate generation
        # In a real implementation, you'd need to:
        # 1. Use proper tokenization
        # 2. Implement generation loop with sampling
        # 3. Handle model-specific requirements
        
        current_text = ""
        words = prompt.split()
        
        # Simulate streaming generation
        for i in range(min(max_length // 4, 20)):  # Limit iterations
            # Simulate word generation
            if len(words) > 0:
                next_word = words[min(i, len(words)-1)] if i < len(words) else "continues"
            else:
                next_word = f"word_{i}"
            
            current_text += " " + next_word if current_text else next_word
            
            # Clean and yield
            cleaned_text = postprocess_text(current_text)
            if cleaned_text.strip():
                yield cleaned_text
                
            time.sleep(0.05)  # Simulate processing time
            
            # Stop if we've generated enough content
            if len(current_text.split()) >= 10:
                break
                
    except Exception as e:
        yield f"Error generating response: {str(e)}"

def tokenize_text(text: str) -> List[int]:
    """
    Simple tokenization for demonstration
    Note: In practice, you'd want to use the model's specific tokenizer
    
    Args:
        text: Input text
        
    Returns:
        List of token IDs
    """
    # Simple character-based tokenization for demonstration
    # This is not suitable for real models - use proper tokenizers
    
    # Convert text to tokens (simple approach)
    tokens = []
    for char in text.lower():
        # Map common characters to token IDs
        if char.isalpha():
            tokens.append(ord(char) - ord('a') + 1)
        elif char.isspace():
            tokens.append(0)  # Space token
        else:
            tokens.append(1)  # Unknown token
    
    # Pad or truncate to a reasonable length
    max_length = 128
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    else:
        tokens.extend([0] * (max_length - len(tokens)))
    
    return tokens

def decode_tokens(tokens: List[int]) -> str:
    """
    Decode token IDs back to text
    
    Args:
        tokens: List of token IDs
        
    Returns:
        Decoded text
    """
    text = ""
    for token in tokens:
        if token == 0:
            text += " "
        elif 1 <= token <= 26:
            text += chr(ord('a') + token - 1)
        # Skip unknown tokens
    
    return text

def sample_next_token(
    logits: np.ndarray, 
    temperature: float = 0.7, 
    top_p: float = 0.9
) -> int:
    """
    Sample next token from logits
    
    Args:
        logits: Model output logits
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        
    Returns:
        Selected token ID
    """
    # Apply temperature
    if temperature > 0:
        logits = logits / temperature
    
    # Convert to probabilities
    probs = softmax(logits)
    
    # Apply top-p filtering
    if top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        
        # Find cutoff for top-p
        cutoff = 1.0 - top_p
        filtered_indices = np.where(cumulative_probs > cutoff)[0]
        if len(filtered_indices) > 0:
            probs[filtered_indices] = 0
            probs = probs / np.sum(probs)  # Renormalize
    
    # Sample from the distribution
    token_id = np.random.choice(len(probs), p=probs)
    return token_id

def softmax(x: np.ndarray) -> np.ndarray:
    """Apply softmax function"""
    exp_x = np.exp(x - np.max(x))  # Numerical stability
    return exp_x / np.sum(exp_x)

def calculate_model_performance(session: ort.InferenceSession) -> Dict[str, Any]:
    """
    Calculate model performance metrics
    
    Args:
        session: ONNX inference session
        
    Returns:
        Dictionary with performance metrics
    """
    metrics = {}
    
    try:
        # Get session info
        metrics["input_count"] = len(session.get_inputs())
        metrics["output_count"] = len(session.get_outputs())
        metrics["input_names"] = [input.name for input in session.get_inputs()]
        metrics["output_names"] = [output.name for output in session.get_outputs()]
        
        # Get provider information
        providers = session.get_providers()
        metrics["execution_providers"] = providers
        metrics["current_provider"] = providers[0] if providers else "Unknown"
        
    except Exception as e:
        metrics["error"] = str(e)
    
    return metrics
This ONNX AI Chat application includes:

## Key Features:

1. **Modern Chat Interface**: Uses Gradio's `ChatInterface` for a clean, interactive chat experience

2. **ONNX Model Integration**: 
   - Load ONNX models from file paths
   - Support for different ONNX models with proper session management
   - Performance optimizations for inference

3. **Configurable Generation Parameters**:
   - Max length, temperature, top-p, repetition penalty
   - Real-time parameter updates

4. **Robust Error Handling**:
   - Model loading validation
   - Generation error handling
   - User-friendly error messages

5. **Streaming Responses**: Incremental response generation for better user experience

6. **Professional UI**:
   - Custom CSS styling
   - Collapsible settings panel
   - Model status indicators
   - Built with anycoder attribution

## Usage:

1. **Load a Model**: Enter your ONNX model path in the settings panel
2. **Configure Parameters**: Adjust generation settings as needed
3. **Start Chatting**: Begin conversation with the AI model

The application provides a complete foundation for ONNX-based text generation chat interfaces. You'll need to adapt the tokenization and generation logic for your specific model architecture.

Note: The current implementation includes placeholder tokenization for demonstration. For production use, replace the tokenization functions with your model's specific tokenizer (e.g., GPT tokenizer, BERT tokenizer, etc.).