Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime as ort | |
| import re | |
| import threading | |
| import time | |
| from typing import List, Dict, Any, Optional | |
| from utils import ( | |
| load_onnx_model, | |
| generate_response, | |
| preprocess_text, | |
| postprocess_text, | |
| setup_chat_prompt | |
| ) | |
# Module-level state shared by the model loader and the chat handler.
# Both handles stay ``None`` until a model has been loaded successfully.
onnx_model = session = None

# Generation parameters forwarded as keyword arguments to ``generate_response``;
# the settings panel overwrites these values through ``update_model_config``.
model_config: Dict[str, Any] = dict(
    max_length=100,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
)
def initialize_model(model_path: Optional[str] = None) -> str:
    """Load an ONNX model into the module-level ``onnx_model`` / ``session``.

    Args:
        model_path: Path passed to ``load_onnx_model``. When it is ``None`` or
            empty nothing is loaded and the globals are left untouched.

    Returns:
        A human-readable status message: a success note, a hint asking for a
        model path, or an error description. Loader exceptions are reported
        through the return value instead of being raised, because the result
        is shown directly in the UI.
    """
    global onnx_model, session

    if not model_path:
        # No default model ships with the app; the user has to supply a path.
        return "ℹ️ Please provide a valid ONNX model path to start chatting"

    try:
        onnx_model, session = load_onnx_model(model_path)
    except Exception as exc:  # UI boundary: surface any loader failure as text
        return f"❌ Error loading model: {exc}"
    return f"✅ Successfully loaded custom model from: {model_path}"
def _build_prompt(message: str, history) -> str:
    """Render prior turns plus the new message as a Human/Assistant transcript."""
    role_labels = {"user": "Human", "assistant": "Assistant"}
    parts = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Role/content style history: one dict per utterance.
            label = role_labels.get(turn.get("role"))
            if label is not None:
                parts.append(f"{label}: {turn.get('content', '')}\n")
        elif len(turn) >= 2:
            # Pair style history: [user_text, assistant_text].
            parts.append(f"Human: {turn[0]}\nAssistant: {turn[1]}\n")
    parts.append(f"Human: {message}\nAssistant:")
    return "".join(parts)


def chat_response(message: str, history: List[Any], model_path: str = "", use_context: bool = True):
    """Stream a chat reply generated by the loaded ONNX session.

    This is a generator: every yielded string is the text to display for the
    current reply. Failures are yielded as user-facing messages rather than
    raised.

    Args:
        message: The new user message.
        history: Previous turns, either ``[user, assistant]`` pairs or
            ``{"role", "content"}`` dicts. Pairs with fewer than two items and
            dicts with an unknown role are skipped.
        model_path: Optional model path used to load the model lazily when no
            session exists yet.
        use_context: When false, ``history`` is ignored and only ``message``
            is sent to the model.

    Yields:
        Post-processed response chunks, or a single error message.
    """
    global session, onnx_model

    # Lazily load the model on first use if the caller supplied a path.
    if session is None:
        if not model_path:
            yield "❌ Please load a model first by providing the ONNX model path in settings."
            return
        try:
            onnx_model, session = load_onnx_model(model_path)
        except Exception as exc:
            # Report the cause too, matching the detail initialize_model gives.
            yield f"❌ Failed to load model. Please check the model path. Details: {exc}"
            return

    try:
        prompt = _build_prompt(message, history if use_context else None)
        processed_prompt = preprocess_text(prompt)

        # Each chunk is cleaned and forwarded as-is.
        # NOTE(review): presumably generate_response yields cumulative text
        # rather than deltas - confirm against utils.generate_response.
        for chunk in generate_response(session, processed_prompt, **model_config):
            yield postprocess_text(chunk)
            # Small delay for better UX
            time.sleep(0.01)
    except Exception as exc:
        yield f"❌ Error generating response: {exc}"
def update_model_config(max_length: int, temperature: float, top_p: float, repetition_penalty: float):
    """Overwrite the shared generation settings with the given values.

    The module-level ``model_config`` dict is mutated in place, so later
    generations use the new values. Nothing is returned.
    """
    model_config.update(
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
def clear_chat():
    """Return a new empty list to use as the reset chat history."""
    empty_history = []
    return empty_history
def load_model_api(model_path: str) -> str:
    """Validate a user-supplied model path and load the model from it.

    Args:
        model_path: Raw text from the model-path field. Surrounding whitespace
            is ignored; ``None`` is treated like an empty string.

    Returns:
        The status message from ``initialize_model``, or a validation error
        when the path is blank.
    """
    cleaned_path = (model_path or "").strip()
    if not cleaned_path:
        return "❌ Please provide a valid ONNX model path."
    return initialize_model(cleaned_path)
| # Create the Gradio interface | |
def create_app():
    """Create and configure the Gradio application.

    Builds a ``gr.Blocks`` layout made of a header, a model-status banner, a
    collapsible settings panel (model path, load button, generation sliders)
    and a ``gr.ChatInterface`` driven by ``chat_response``.

    Returns:
        The assembled ``gr.Blocks`` instance; launching it is left to the caller.
    """
    import html  # local import: only needed to escape text shown in the banner

    # Custom CSS for better styling
    css = """
    .chatbot-container {
        max-width: 1200px;
        margin: 0 auto;
    }
    .header-text {
        text-align: center;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        font-size: 2.5em;
        font-weight: bold;
        margin-bottom: 10px;
    }
    .subtitle-text {
        text-align: center;
        color: #666;
        margin-bottom: 30px;
        font-size: 1.1em;
    }
    .model-status {
        padding: 10px;
        border-radius: 8px;
        margin-bottom: 20px;
        text-align: center;
    }
    .model-loaded {
        background-color: #d4edda;
        border: 1px solid #c3e6cb;
        color: #155724;
    }
    .model-not-loaded {
        background-color: #f8d7da;
        border: 1px solid #f5c6cb;
        color: #721c24;
    }
    """

    def render_status_banner(status):
        """Wrap the latest load-status text in the styled banner markup."""
        # The colour reflects whether a session currently exists; the text is
        # escaped because it can echo the user-typed model path.
        state_class = "model-loaded" if session is not None else "model-not-loaded"
        return f'<div class="model-status {state_class}">{html.escape(status or "")}</div>'

    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        # Header
        gr.HTML("""
        <div class="header-text">🤖 ONNX AI Chat</div>
        <div class="subtitle-text">Chat with AI models using ONNX runtime</div>
        <div style="text-align: center; margin-bottom: 20px;">
        <span>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></span>
        </div>
        """)

        # Model status indicator
        model_status = gr.HTML(
            '<div class="model-status model-not-loaded">❌ No model loaded - Please load a model to start chatting</div>'
        )

        # Settings panel
        with gr.Accordion("⚙️ Model Settings & Configuration", open=False):
            model_path_input = gr.Textbox(
                label="ONNX Model Path",
                placeholder="Enter the path to your ONNX model file...",
                info="Provide the path to a valid ONNX model for text generation",
            )
            load_model_btn = gr.Button("🔄 Load Model", variant="primary")
            model_load_status = gr.Textbox(label="Model Load Status", interactive=False)

            # Generation parameters
            with gr.Row():
                max_length = gr.Slider(10, 500, value=100, step=10, label="Max Length")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            with gr.Row():
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top P")
                repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, step=0.05, label="Repetition Penalty")
            update_config_btn = gr.Button("🔧 Update Settings", variant="secondary")

            # Connect config updates
            update_config_btn.click(
                update_model_config,
                inputs=[max_length, temperature, top_p, repetition_penalty],
                outputs=[],
            )

        # Chat interface.
        # ``model_path_input`` is already rendered in the settings accordion, so
        # no separate accordion option is passed here (the previous
        # ``additional_inputs_accordion_id`` keyword is not a ChatInterface
        # parameter and made construction fail).
        # NOTE(review): retry_btn / undo_btn / clear_btn and the ``clear_btn``
        # attribute used below are Gradio 4.x API - confirm the pinned version.
        chatbot = gr.ChatInterface(
            fn=chat_response,
            title="💬 Chat with AI",
            description="Start a conversation! Load a model first to begin chatting.",
            retry_btn="🔄 Retry",
            undo_btn="↩️ Undo",
            clear_btn="🗑️ Clear",
            additional_inputs=[model_path_input],
        )

        # Connect model loading: show the status text, then refresh the banner.
        load_model_btn.click(
            load_model_api,
            inputs=[model_path_input],
            outputs=[model_load_status],
        ).then(
            render_status_banner,
            inputs=[model_load_status],
            outputs=[model_status],
        )

        # Clear chat functionality
        chatbot.clear_btn.click(
            clear_chat,
            outputs=[chatbot.chatbot_state],
        )

    return demo
if __name__ == "__main__":
    # Build the UI first, then serve it with the options below.
    app = create_app()

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "quiet": False,
    }
    app.launch(**launch_options)