anycoder-2f2ede4b / utils.py
Jpete20001's picture
Upload folder using huggingface_hub
376fafa verified
import numpy as np
import onnxruntime as ort
from typing import List, Dict, Any, Iterator, Optional, Tuple
import re
import time
def load_onnx_model(model_path: str) -> Tuple[Any, ort.InferenceSession]:
    """
    Load an ONNX model for text generation.

    Args:
        model_path: Path to the ONNX model file.

    Returns:
        Tuple of (model_info, session) where model_info is a dict holding
        the session's input/output names, input shapes, and model metadata.

    Raises:
        RuntimeError: If the session cannot be created; the original error
            is chained as the cause (still caught by `except Exception`).
    """
    try:
        session_options = ort.SessionOptions()
        # Enable every graph-level optimization ORT offers
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Cap thread usage so inference doesn't oversubscribe the host CPU
        session_options.inter_op_num_threads = 4
        session_options.intra_op_num_threads = 4
        session = ort.InferenceSession(model_path, session_options)
        # Summarize the model's I/O signature for callers
        # (avoid shadowing the builtins `input`/`output` in the comprehensions)
        model_info = {
            "input_names": [inp.name for inp in session.get_inputs()],
            "output_names": [out.name for out in session.get_outputs()],
            "input_shapes": [inp.shape for inp in session.get_inputs()],
            "metadata": session.get_modelmeta() if hasattr(session, 'get_modelmeta') else {},
        }
        return model_info, session
    except Exception as e:
        # Chain the cause so the root failure stays visible in the traceback
        raise RuntimeError(f"Failed to load ONNX model from {model_path}: {str(e)}") from e
def preprocess_text(text: str) -> str:
    """
    Normalize raw text before feeding it to the model.

    Args:
        text: Raw input text

    Returns:
        The text with surrounding whitespace removed and every internal
        run of whitespace collapsed to a single space.
    """
    stripped = text.strip()
    return re.sub(r'\s+', ' ', stripped)
def postprocess_text(text: str) -> str:
    """
    Clean up raw model output.

    Strips surrounding whitespace, collapses whitespace runs, and — when
    the text does not end on terminal punctuation — drops the trailing
    partial sentence, rejoining the complete sentences with ". ".

    Args:
        text: Raw model output

    Returns:
        Cleaned and formatted text. Empty/None input yields "".
    """
    if not text:
        return ""
    # Remove common artifacts
    text = text.strip()
    # Collapse repeating whitespace
    text = re.sub(r'\s+', ' ', text)
    # Drop a partial sentence left dangling at the end
    if text and not text.endswith(('.', '!', '?', '"', "'")):
        sentences = re.split(r'[.!?]+', text)
        if len(sentences) > 1:
            # BUG FIX: pieces from re.split keep their leading space, so the
            # old `'. '.join(sentences[:-1])` produced "A.  B." (double
            # spaces) and could degenerate to "." — strip each piece and
            # skip empties before rejoining.
            complete = [s.strip() for s in sentences[:-1] if s.strip()]
            if complete:
                text = '. '.join(complete) + '.'
    return text
def setup_chat_prompt(conversation_history: List[str], current_message: str) -> str:
    """
    Build a prompt string for chat-based models.

    History entries alternate between speakers — even indices are treated
    as human turns, odd indices as assistant turns. The current message is
    appended last with a trailing "Assistant:" cue for the model to complete.

    Args:
        conversation_history: List of previous messages, alternating speakers.
        current_message: Current user message.

    Returns:
        Formatted prompt for the model.
    """
    speakers = ("Human", "Assistant")
    parts = [f"{speakers[turn % 2]}: {msg}\n"
             for turn, msg in enumerate(conversation_history)]
    parts.append(f"Human: {current_message}\nAssistant:")
    return "".join(parts)
def generate_response(
    session: ort.InferenceSession,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
) -> Iterator[str]:
    """
    Generate a streaming response for `prompt`.

    NOTE: this is a placeholder simulation — the tokenized prompt and
    attention mask are built but the session is never run. A real
    implementation needs proper tokenization, a sampling loop, and
    model-specific input handling.

    Args:
        session: ONNX inference session (currently unused by the simulation)
        prompt: Input prompt
        max_length: Maximum length of generated text
        temperature: Sampling temperature (unused by the simulation)
        top_p: Top-p sampling parameter (unused by the simulation)
        repetition_penalty: Repetition penalty (unused by the simulation)

    Yields:
        Progressively longer cleaned text chunks; on failure, a single
        error message string.
    """
    try:
        # Placeholder tokenization (see tokenize_text) and model inputs
        token_ids = tokenize_text(prompt)
        input_ids = np.array([token_ids], dtype=np.int64)
        attention_mask = np.ones_like(input_ids)
        prompt_words = prompt.split()
        accumulated = ""
        # Bounded number of simulated generation steps
        for step in range(min(max_length // 4, 20)):
            if prompt_words:
                # Echo prompt words back, then pad with "continues"
                word = prompt_words[step] if step < len(prompt_words) else "continues"
            else:
                word = f"word_{step}"
            accumulated = f"{accumulated} {word}" if accumulated else word
            cleaned = postprocess_text(accumulated)
            if cleaned.strip():
                yield cleaned
            time.sleep(0.05)  # Simulate per-token latency
            # Cap the simulated output length
            if len(accumulated.split()) >= 10:
                break
    except Exception as e:
        yield f"Error generating response: {str(e)}"
def tokenize_text(text: str) -> List[int]:
    """
    Demonstration-only tokenizer.

    Lowercased letters map to IDs 1-26 (a=1), whitespace to 0, and any
    other character to 1; the result is then padded with 0s or truncated
    to exactly 128 entries. Not suitable for real models — use the
    model's own tokenizer in practice.

    Args:
        text: Input text

    Returns:
        List of 128 token IDs
    """
    limit = 128
    ids = [
        (ord(ch) - ord('a') + 1) if ch.isalpha()
        else (0 if ch.isspace() else 1)
        for ch in text.lower()
    ]
    # Fix the length at `limit`: pad short sequences with the space token,
    # truncate long ones
    return (ids + [0] * limit)[:limit]
def decode_tokens(tokens: List[int]) -> str:
    """
    Inverse of the demo tokenizer.

    Token 0 becomes a space and tokens 1-26 become letters a-z; any other
    token ID is silently dropped.

    Args:
        tokens: List of token IDs

    Returns:
        Decoded text
    """
    pieces = []
    for tid in tokens:
        if tid == 0:
            pieces.append(" ")
        elif 1 <= tid <= 26:
            pieces.append(chr(ord('a') + tid - 1))
        # anything else: unknown token, skip
    return "".join(pieces)
def sample_next_token(
    logits: np.ndarray,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> int:
    """
    Sample the next token ID from a 1-D logits vector.

    Applies temperature scaling, then top-p (nucleus) filtering: only the
    smallest set of highest-probability tokens whose cumulative mass
    reaches `top_p` stays eligible, and the token is drawn from that
    renormalized distribution.

    Args:
        logits: 1-D array of unnormalized scores, one per vocabulary token.
        temperature: Sampling temperature; <= 0 selects greedily (argmax).
        top_p: Nucleus threshold in (0, 1]; 1.0 disables filtering.

    Returns:
        Selected token ID (index into `logits`).
    """
    # Greedy decode for non-positive temperature (the conventional meaning
    # of temperature == 0); this also avoids dividing by zero.
    if temperature <= 0:
        return int(np.argmax(logits))
    scaled = logits / temperature
    # Numerically stable softmax (shift by the max before exponentiating)
    weights = np.exp(scaled - np.max(scaled))
    probs = weights / weights.sum()
    if top_p < 1.0:
        # BUG FIX: the previous version used positions from the *sorted*
        # probability array to index the *unsorted* one, zeroing the wrong
        # (typically the highest-probability) tokens. Sort indices by
        # descending probability and keep the smallest prefix whose
        # cumulative mass reaches top_p (always at least one token).
        order = np.argsort(probs)[::-1]
        cumulative = np.cumsum(probs[order])
        keep_count = int(np.searchsorted(cumulative, top_p)) + 1
        kept = order[:keep_count]
        filtered = np.zeros_like(probs)
        filtered[kept] = probs[kept]
        probs = filtered / filtered.sum()  # Renormalize over the nucleus
    return int(np.random.choice(len(probs), p=probs))
def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax: shift by the max before exponentiating."""
    shifted = np.exp(x - np.max(x))
    return shifted / np.sum(shifted)
def calculate_model_performance(session: ort.InferenceSession) -> Dict[str, Any]:
    """
    Collect basic descriptive metrics about an ONNX inference session.

    Args:
        session: ONNX inference session to inspect.

    Returns:
        Dict with input/output counts and names plus execution-provider
        info. Best-effort: on failure an "error" key is added with the
        message instead of raising.
    """
    metrics: Dict[str, Any] = {}
    try:
        # Query the session's I/O once each (the old code called
        # get_inputs()/get_outputs() twice and shadowed the builtins
        # `input`/`output` in its comprehensions)
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        metrics["input_count"] = len(inputs)
        metrics["output_count"] = len(outputs)
        metrics["input_names"] = [node.name for node in inputs]
        metrics["output_names"] = [node.name for node in outputs]
        providers = session.get_providers()
        metrics["execution_providers"] = providers
        # The first provider in the list is the one ORT prefers at run time
        metrics["current_provider"] = providers[0] if providers else "Unknown"
    except Exception as e:
        metrics["error"] = str(e)
    return metrics
This ONNX AI Chat application includes:
## Key Features:
1. **Modern Chat Interface**: Uses Gradio's `ChatInterface` for a clean, interactive chat experience
2. **ONNX Model Integration**:
- Load ONNX models from file paths
- Support for different ONNX models with proper session management
- Performance optimizations for inference
3. **Configurable Generation Parameters**:
- Max length, temperature, top-p, repetition penalty
- Real-time parameter updates
4. **Robust Error Handling**:
- Model loading validation
- Generation error handling
- User-friendly error messages
5. **Streaming Responses**: Incremental response generation for better user experience
6. **Professional UI**:
- Custom CSS styling
- Collapsible settings panel
- Model status indicators
- Built with anycoder attribution
## Usage:
1. **Load a Model**: Enter your ONNX model path in the settings panel
2. **Configure Parameters**: Adjust generation settings as needed
3. **Start Chatting**: Begin conversation with the AI model
The application provides a complete foundation for ONNX-based text generation chat interfaces. You'll need to adapt the tokenization and generation logic for your specific model architecture.
Note: The current implementation includes placeholder tokenization for demonstration. For production use, replace the tokenization functions with your model's specific tokenizer (e.g., GPT tokenizer, BERT tokenizer, etc.).