Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import onnxruntime as ort | |
| from typing import List, Dict, Any, Iterator, Optional, Tuple | |
| import re | |
| import time | |
def load_onnx_model(model_path: str) -> Tuple[Any, ort.InferenceSession]:
    """
    Load an ONNX model for text generation.

    Args:
        model_path: Path to the ONNX model file

    Returns:
        Tuple of (model_info, session) where model_info is a dict holding the
        session's input/output names, input shapes, and model metadata.

    Raises:
        RuntimeError: If the session cannot be created from ``model_path``.
    """
    try:
        # Configure ONNX Runtime session options
        session_options = ort.SessionOptions()
        # Enable all graph-level optimizations
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Bound inter-op / intra-op parallelism for predictable performance
        session_options.inter_op_num_threads = 4
        session_options.intra_op_num_threads = 4
        # Create inference session
        session = ort.InferenceSession(model_path, session_options)
        # Summarize the loaded graph for callers / diagnostics.
        # (Fetch inputs once instead of twice; avoid shadowing the
        # builtins `input`/`output` in the comprehensions.)
        graph_inputs = session.get_inputs()
        model_info = {
            "input_names": [node.name for node in graph_inputs],
            "output_names": [node.name for node in session.get_outputs()],
            "input_shapes": [node.shape for node in graph_inputs],
            # NOTE: get_modelmeta returns a ModelMetadata object, not a dict
            "metadata": session.get_modelmeta() if hasattr(session, 'get_modelmeta') else {}
        }
        return model_info, session
    except Exception as e:
        # RuntimeError subclasses Exception, so callers catching Exception
        # still work; chain the cause for debuggability.
        raise RuntimeError(f"Failed to load ONNX model from {model_path}: {str(e)}") from e
def preprocess_text(text: str) -> str:
    """
    Prepare raw user text for the model.

    Args:
        text: Raw input text

    Returns:
        Text with surrounding whitespace removed and internal runs of
        whitespace collapsed to single spaces.
    """
    trimmed = text.strip()
    return re.sub(r'\s+', ' ', trimmed)
def postprocess_text(text: str) -> str:
    """
    Postprocess model output into clean, presentable text.

    Args:
        text: Raw model output

    Returns:
        Cleaned and formatted text. If the text does not already end with
        terminal punctuation, the trailing sentence fragment is dropped.
    """
    if not text:
        return ""
    # Normalize whitespace artifacts from generation
    text = text.strip()
    text = re.sub(r'\s+', ' ', text)
    # Drop a trailing partial sentence so the output ends cleanly
    if text and not text.endswith(('.', '!', '?', '"', "'")):
        sentences = re.split(r'[.!?]+', text)
        if len(sentences) > 1:
            # Strip each fragment before rejoining: re.split leaves a leading
            # space on every fragment after the first, which the previous
            # '. '.join turned into double spaces in the output.
            complete = [s.strip() for s in sentences[:-1] if s.strip()]
            text = '. '.join(complete) + '.'
    return text
def setup_chat_prompt(conversation_history: List[str], current_message: str) -> str:
    """
    Build a chat-style prompt from alternating history turns.

    Even-indexed history entries are treated as user ("Human") turns and
    odd-indexed entries as model ("Assistant") turns.

    Args:
        conversation_history: List of previous messages
        current_message: Current user message

    Returns:
        Formatted prompt ending with an open "Assistant:" turn
    """
    pieces = []
    for index, message in enumerate(conversation_history):
        speaker = "Human" if index % 2 == 0 else "Assistant"
        pieces.append(f"{speaker}: {message}\n")
    # Close with the new user message and an open assistant turn
    pieces.append(f"Human: {current_message}\nAssistant:")
    return "".join(pieces)
def generate_response(
    session: ort.InferenceSession,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
) -> Iterator[str]:
    """
    Generate a response with streaming output.

    NOTE(review): this is a placeholder that simulates generation by echoing
    the prompt's words; ``session`` and the sampling parameters are accepted
    but not used for real inference. For production, wire in the model's
    tokenizer and an actual sampling loop.

    Args:
        session: ONNX inference session
        prompt: Input prompt
        max_length: Maximum length of generated text
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        repetition_penalty: Repetition penalty

    Yields:
        Progressively longer cleaned text snapshots
    """
    try:
        # Placeholder tokenization (demo character-level tokenizer).
        # These arrays are prepared but unused until real inference exists.
        prompt_tokens = tokenize_text(prompt)
        token_batch = np.array([prompt_tokens], dtype=np.int64)
        valid_mask = np.ones_like(token_batch)  # all positions are real tokens
        source_words = prompt.split()
        accumulated = ""
        # Cap the number of simulated generation steps
        step_budget = min(max_length // 4, 20)
        for step in range(step_budget):
            if source_words:
                # Echo the prompt word by word, then repeat a filler word
                word = source_words[step] if step < len(source_words) else "continues"
            else:
                word = f"word_{step}"
            accumulated = f"{accumulated} {word}" if accumulated else word
            snapshot = postprocess_text(accumulated)
            if snapshot.strip():
                yield snapshot
            time.sleep(0.05)  # Simulate per-token latency
            # Stop once enough content has been produced
            if len(accumulated.split()) >= 10:
                break
    except Exception as e:
        yield f"Error generating response: {str(e)}"
def tokenize_text(text: str, max_length: int = 128) -> List[int]:
    """
    Simple character-level tokenization for demonstration.

    Note: not suitable for real models — use the model's own tokenizer in
    production.

    Args:
        text: Input text
        max_length: Fixed output length; the result is truncated or
            zero-padded to exactly this many tokens. Defaults to 128,
            matching the previously hard-coded value.

    Returns:
        List of exactly ``max_length`` token IDs: 1-26 for ASCII letters
        a-z, 0 for whitespace (also used as padding), 1 for anything else.
    """
    tokens: List[int] = []
    for char in text.lower():
        if char.isalpha():
            # Map a-z onto 1..26. NOTE(review): non-ASCII letters also pass
            # isalpha() and map outside 1-26 (original behavior, kept).
            tokens.append(ord(char) - ord('a') + 1)
        elif char.isspace():
            tokens.append(0)  # Space token (shared with padding)
        else:
            tokens.append(1)  # Unknown token (collides with 'a' — demo only)
    # Pad or truncate to the fixed length
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    else:
        tokens.extend([0] * (max_length - len(tokens)))
    return tokens
def decode_tokens(tokens: List[int]) -> str:
    """
    Decode token IDs back to text.

    Inverse of the demo tokenizer: 0 -> space, 1-26 -> a-z; any other ID
    is silently dropped.

    Args:
        tokens: List of token IDs

    Returns:
        Decoded text
    """
    pieces = []
    for token_id in tokens:
        if token_id == 0:
            pieces.append(" ")
        elif 1 <= token_id <= 26:
            pieces.append(chr(ord('a') + token_id - 1))
        # Unknown IDs are skipped entirely
    return "".join(pieces)
| def sample_next_token( | |
| logits: np.ndarray, | |
| temperature: float = 0.7, | |
| top_p: float = 0.9 | |
| ) -> int: | |
| """ | |
| Sample next token from logits | |
| Args: | |
| logits: Model output logits | |
| temperature: Sampling temperature | |
| top_p: Top-p sampling parameter | |
| Returns: | |
| Selected token ID | |
| """ | |
| # Apply temperature | |
| if temperature > 0: | |
| logits = logits / temperature | |
| # Convert to probabilities | |
| probs = softmax(logits) | |
| # Apply top-p filtering | |
| if top_p < 1.0: | |
| sorted_probs = np.sort(probs)[::-1] | |
| cumulative_probs = np.cumsum(sorted_probs) | |
| # Find cutoff for top-p | |
| cutoff = 1.0 - top_p | |
| filtered_indices = np.where(cumulative_probs > cutoff)[0] | |
| if len(filtered_indices) > 0: | |
| probs[filtered_indices] = 0 | |
| probs = probs / np.sum(probs) # Renormalize | |
| # Sample from the distribution | |
| token_id = np.random.choice(len(probs), p=probs) | |
| return token_id | |
def softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x``, max-shifted for numerical stability."""
    exps = np.exp(x - np.max(x))
    total = np.sum(exps)
    return exps / total
def calculate_model_performance(session: ort.InferenceSession) -> Dict[str, Any]:
    """
    Collect basic descriptive metrics about an inference session.

    Args:
        session: ONNX inference session

    Returns:
        Dictionary with input/output counts and names plus the execution
        providers. On failure, contains an "error" key with the message
        instead of raising.
    """
    metrics: Dict[str, Any] = {}
    try:
        graph_inputs = session.get_inputs()
        graph_outputs = session.get_outputs()
        metrics["input_count"] = len(graph_inputs)
        metrics["output_count"] = len(graph_outputs)
        metrics["input_names"] = [node.name for node in graph_inputs]
        metrics["output_names"] = [node.name for node in graph_outputs]
        # Provider information (first provider is the one actually used)
        providers = session.get_providers()
        metrics["execution_providers"] = providers
        metrics["current_provider"] = providers[0] if providers else "Unknown"
    except Exception as e:
        # Best-effort diagnostics: report the failure in the result
        metrics["error"] = str(e)
    return metrics
This ONNX AI Chat application includes:

## Key Features:

1. **Modern Chat Interface**: Uses Gradio's `ChatInterface` for a clean, interactive chat experience
2. **ONNX Model Integration**:
   - Load ONNX models from file paths
   - Support for different ONNX models with proper session management
   - Performance optimizations for inference
3. **Configurable Generation Parameters**:
   - Max length, temperature, top-p, repetition penalty
   - Real-time parameter updates
4. **Robust Error Handling**:
   - Model loading validation
   - Generation error handling
   - User-friendly error messages
5. **Streaming Responses**: Incremental response generation for a better user experience
6. **Professional UI**:
   - Custom CSS styling
   - Collapsible settings panel
   - Model status indicators
   - Built with anycoder attribution

## Usage:

1. **Load a Model**: Enter your ONNX model path in the settings panel
2. **Configure Parameters**: Adjust generation settings as needed
3. **Start Chatting**: Begin a conversation with the AI model

The application provides a complete foundation for ONNX-based text-generation chat interfaces. You will need to adapt the tokenization and generation logic to your specific model architecture.

Note: The current implementation includes placeholder tokenization for demonstration only. For production use, replace the tokenization functions with your model's specific tokenizer (e.g., a GPT or BERT tokenizer).