"""Gradio chat front-end for text generation with an ONNX model."""

import re
import threading
import time
from typing import Any, Dict, Iterator, List, Optional

import gradio as gr
import numpy as np
import onnxruntime as ort

from utils import (
    generate_response,
    load_onnx_model,
    postprocess_text,
    preprocess_text,
    setup_chat_prompt,
)

# Global variables for model and session
onnx_model = None
session = None
model_config = {
    "max_length": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}


def initialize_model(model_path: Optional[str] = None) -> str:
    """Load the ONNX model at *model_path* into the module-level globals.

    Args:
        model_path: Path handed to ``load_onnx_model``. When empty or
            ``None`` nothing is loaded.

    Returns:
        A human-readable status message (success, hint, or error). Loading
        errors are reported in the message rather than raised.
    """
    global onnx_model, session

    if not model_path:
        # Placeholder branch: there is no bundled default model to fall back on.
        return "ℹ️ Please provide a valid ONNX model path to start chatting"

    try:
        onnx_model, session = load_onnx_model(model_path)
    except Exception as e:
        # UI boundary: surface the failure as a status string instead of crashing.
        return f"❌ Error loading model: {e}"
    return f"✅ Successfully loaded custom model from: {model_path}"


def _build_prompt(message: str, history: Optional[List[List[str]]], use_context: bool) -> str:
    """Render the conversation as a "Human:/Assistant:" transcript ending on an open Assistant turn."""
    turns: List[str] = []
    if use_context and history:
        # History entries with fewer than two items are skipped.
        turns = [
            f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
            for turn in history
            if len(turn) >= 2
        ]
    return "".join(turns) + f"Human: {message}\nAssistant:"


def chat_response(
    message: str,
    history: List[List[str]],
    model_path: str = "",
    use_context: bool = True,
) -> Iterator[str]:
    """Stream a chat reply for *message* from the loaded ONNX session.

    If no session is loaded yet and *model_path* is given, the model is
    loaded lazily first.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs.
        model_path: Optional model path used for lazy loading.
        use_context: When true, previous turns are included in the prompt.

    Yields:
        The post-processed text of each chunk produced by
        ``generate_response``, or a single error message string.
    """
    global session, onnx_model

    # Check if model is loaded; try a lazy load when a path is available.
    if session is None:
        if not model_path:
            yield "❌ Please load a model first by providing the ONNX model path in settings."
            return
        try:
            onnx_model, session = load_onnx_model(model_path)
        except Exception:
            yield "❌ Failed to load model. Please check the model path."
            return

    try:
        processed_prompt = preprocess_text(
            _build_prompt(message, history, use_context)
        )

        # NOTE(review): each chunk is presumably the cumulative response so far
        # (every yield replaces the displayed message) - confirm against
        # utils.generate_response.
        for chunk in generate_response(session, processed_prompt, **model_config):
            yield postprocess_text(chunk)
            # Small delay for better UX
            time.sleep(0.01)
    except Exception as e:
        yield f"❌ Error generating response: {e}"


def update_model_config(
    max_length: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
) -> None:
    """Overwrite the generation parameters stored in ``model_config``."""
    model_config.update(
        {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        }
    )


def clear_chat() -> list:
    """Return an empty chat history."""
    return []


def load_model_api(model_path: str) -> str:
    """Validate *model_path* and load the model, returning a status message.

    A missing, empty, or whitespace-only path yields an error message
    without attempting a load.
    """
    cleaned_path = model_path.strip() if model_path else ""
    if not cleaned_path:
        return "❌ Please provide a valid ONNX model path."
    return initialize_model(cleaned_path)


# Create the Gradio interface
def create_app():
    """Create and configure the Gradio application.

    Returns:
        The ``gr.Blocks`` instance holding the header, the settings
        accordion, and the chat interface, with all event handlers wired.
    """
    # Custom CSS for better styling
    css = """
    .chatbot-container {
        max-width: 1200px;
        margin: 0 auto;
    }
    .header-text {
        text-align: center;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        font-size: 2.5em;
        font-weight: bold;
        margin-bottom: 10px;
    }
    .subtitle-text {
        text-align: center;
        color: #666;
        margin-bottom: 30px;
        font-size: 1.1em;
    }
    .model-status {
        padding: 10px;
        border-radius: 8px;
        margin-bottom: 20px;
        text-align: center;
    }
    .model-loaded {
        background-color: #d4edda;
        border: 1px solid #c3e6cb;
        color: #155724;
    }
    .model-not-loaded {
        background-color: #f8d7da;
        border: 1px solid #f5c6cb;
        color: #721c24;
    }
    """

    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        # Header
        gr.HTML("""
🤖 ONNX AI Chat
Chat with AI models using ONNX runtime
Built with anycoder
""")

        # Model status indicator. The text spans several lines, so it has to
        # be a triple-quoted literal (a single-quoted one is a syntax error).
        model_status = gr.HTML(
            """
❌ No model loaded - Please load a model to start chatting
"""
        )

        # Settings panel
        with gr.Accordion("⚙️ Model Settings & Configuration", open=False):
            model_path_input = gr.Textbox(
                label="ONNX Model Path",
                placeholder="Enter the path to your ONNX model file...",
                info="Provide the path to a valid ONNX model for text generation",
            )
            load_model_btn = gr.Button("🔄 Load Model", variant="primary")
            model_load_status = gr.Textbox(label="Model Load Status", interactive=False)

            # Generation parameters
            with gr.Row():
                max_length = gr.Slider(10, 500, value=100, step=10, label="Max Length")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            with gr.Row():
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top P")
                repetition_penalty = gr.Slider(
                    0.5, 2.0, value=1.1, step=0.05, label="Repetition Penalty"
                )

            update_config_btn = gr.Button("🔧 Update Settings", variant="secondary")

            # Connect config updates
            update_config_btn.click(
                update_model_config,
                inputs=[max_length, temperature, top_p, repetition_penalty],
                outputs=[],
            )

        # Chat interface. `model_path_input` is already rendered in the
        # settings accordion above, so no extra accordion option is passed
        # (the previous `additional_inputs_accordion_id` keyword is not a
        # ChatInterface parameter and raised TypeError).
        chatbot = gr.ChatInterface(
            fn=chat_response,
            title="💬 Chat with AI",
            description="Start a conversation! Load a model first to begin chatting.",
            retry_btn="🔄 Retry",
            undo_btn="↩️ Undo",
            clear_btn="🗑️ Clear",
            additional_inputs=[model_path_input],
        )

        # Connect model loading: show the load result, then mirror it into
        # the status banner.
        load_model_btn.click(
            load_model_api,
            inputs=[model_path_input],
            outputs=[model_load_status],
        ).then(
            lambda status: status,
            inputs=[model_load_status],
            outputs=[model_status],
        )

        # Clear chat functionality
        chatbot.clear_btn.click(
            clear_chat,
            outputs=[chatbot.chatbot_state],
        )

    return demo


if __name__ == "__main__":
    # Create and launch the app
    app = create_app()

    # Launch with appropriate settings
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False,
    )