import gradio as gr
import numpy as np
import onnxruntime as ort
import re
import threading
import time
from typing import List, Dict, Any, Optional

from utils import (
    load_onnx_model,
    generate_response,
    preprocess_text,
    postprocess_text,
    setup_chat_prompt,
)

# Global variables for model and session.
# Both are filled in by load_onnx_model() and remain None until a model loads.
onnx_model = None
session = None

# Generation parameters, forwarded as keyword arguments to generate_response().
# Mutated in place by update_model_config().
model_config = {
    "max_length": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}


def initialize_model(model_path: Optional[str] = None) -> str:
    """Load the ONNX model at *model_path* into the module-level globals.

    Args:
        model_path: Path passed to ``load_onnx_model``. When empty or ``None``
            nothing is loaded and a hint message is returned.

    Returns:
        A user-facing status string. Exceptions raised by ``load_onnx_model``
        are reported in this string rather than propagated.
    """
    global onnx_model, session

    if not model_path:
        # No default model is bundled; the user must supply a path.
        return "ℹ️ Please provide a valid ONNX model path to start chatting"

    try:
        onnx_model, session = load_onnx_model(model_path)
    except Exception as e:
        return f"❌ Error loading model: {e}"
    return f"✅ Successfully loaded custom model from: {model_path}"


def _build_prompt(message: str, history: List[List[str]], use_context: bool) -> str:
    """Render the ``Human:`` / ``Assistant:`` transcript sent to the model.

    Earlier turns from *history* are included only when *use_context* is
    true; history entries with fewer than two items are skipped.
    """
    turns = []
    if use_context and history:
        turns = [
            f"Human: {msg[0]}\nAssistant: {msg[1]}\n"
            for msg in history
            if len(msg) >= 2
        ]
    turns.append(f"Human: {message}\nAssistant:")
    return "".join(turns)


def chat_response(
    message: str,
    history: List[List[str]],
    model_path: str = "",
    use_context: bool = True,
):
    """Stream a reply to *message* using the loaded ONNX session.

    If no session is loaded yet, the model at *model_path* (whitespace
    stripped) is loaded first. This is a generator. Each chunk produced by
    ``generate_response`` is passed through ``postprocess_text`` and yielded.
    The original code treated each chunk as the full response so far, which
    presumably means the chunks are cumulative; confirm against ``utils``.
    On failure, a single error message is yielded instead.
    """
    global session, onnx_model

    if session is None:
        model_path = (model_path or "").strip()
        if not model_path:
            yield "❌ Please load a model first by providing the ONNX model path in settings."
            return
        try:
            onnx_model, session = load_onnx_model(model_path)
        except Exception as e:
            # Surface the underlying cause instead of hiding it behind a generic message.
            yield f"❌ Failed to load model. Please check the model path. ({e})"
            return

    try:
        prompt = _build_prompt(message, history, use_context)
        processed_prompt = preprocess_text(prompt)
        for chunk in generate_response(session, processed_prompt, **model_config):
            yield postprocess_text(chunk)
            # Small delay for better UX
            time.sleep(0.01)
    except Exception as e:
        yield f"❌ Error generating response: {str(e)}"


def update_model_config(max_length: int, temperature: float, top_p: float, repetition_penalty: float):
    """Overwrite the generation parameters used by later chat_response calls."""
    # Mutating the dict in place needs no `global` declaration.
    model_config.update({
        "max_length": max_length,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    })


def clear_chat():
    """Clear chat history by returning an empty history list."""
    return []


def load_model_api(model_path: str) -> str:
    """Validate *model_path* and load it via initialize_model.

    Returns the status string from initialize_model, or an error string when
    the path is empty, whitespace-only, or ``None``.
    """
    model_path = (model_path or "").strip()
    if not model_path:
        return "❌ Please provide a valid ONNX model path."
    return initialize_model(model_path)


# NOTE(review): the line below is the still-collapsed start of create_app();
# its body continues on the following lines and must be reformatted together
# with that continuation (as written, it is a single comment line).
# Create the Gradio interface def create_app(): """Create and configure the Gradio application""" # Custom CSS for better styling css = """ .chatbot-container { max-width: 1200px; margin: 0 auto; } .header-text { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-size: 2.5em; font-weight: bold; margin-bottom: 10px; } .subtitle-text { text-align: center; color: #666; margin-bottom: 30px; font-size: 1.1em; } .model-status { padding: 10px; border-radius: 8px; margin-bottom: 20px; text-align: center; } .model-loaded { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; } .model-not-loaded { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; } """ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Header gr.HTML("""