import gradio as gr
from huggingface_hub import InferenceClient
import json
import os
import shutil
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
import joblib
import logging
# ---------------------------
# Logging Configuration
# ---------------------------
logging.basicConfig(
    filename='app.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)
# ---------------------------
# Initialize the HuggingFace API Client
# ---------------------------
# Note: 'gpt-3.5-turbo' is an OpenAI model and is not served by the Hugging Face
# Inference API, so requests against it fail. Use a chat model hosted on the Hub
# instead; zephyr-7b-beta below is just an example and can be swapped for any
# chat-capable Hub model you have access to. Authentication uses the token saved
# by `huggingface-cli login` or one passed via the `token` argument.
try:
    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
    logger.info("HuggingFace InferenceClient initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize HuggingFace InferenceClient: {e}")
    raise
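# Minimal sketch of a direct, non-streaming call, handy for verifying the
# endpoint before wiring up the UI (left commented out; assumes network access
# and the example model above):
# result = client.chat_completion(
#     [{"role": "user", "content": "ping"}], max_tokens=16
# )
# print(result.choices[0].message.content)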
# ---------------------------
# Persistent Memory and Knowledge Base Setup
# ---------------------------
memory_file = "chat_memory.json"
knowledge_base_dir = "knowledge_base"
model_file = "chat_model.pkl"

# Ensure directories exist
os.makedirs(knowledge_base_dir, exist_ok=True)
# ---------------------------
# Memory Management Functions
# ---------------------------
def load_memory():
    """Load conversation memory from a JSON file."""
    try:
        if os.path.exists(memory_file):
            with open(memory_file, "r") as f:
                memory = json.load(f)
            logger.info("Conversation memory loaded successfully.")
            return memory
        logger.info("No existing conversation memory found. Starting fresh.")
        return []
    except Exception as e:
        logger.error(f"Error loading memory: {e}")
        return []
def save_memory(memory):
    """Save conversation memory to a JSON file."""
    try:
        with open(memory_file, "w") as f:
            json.dump(memory, f, indent=2)
        logger.info("Conversation memory saved successfully.")
    except Exception as e:
        logger.error(f"Error saving memory: {e}")
def update_memory(message, response):
    """Append user message and assistant response to memory."""
    try:
        memory = load_memory()
        memory.append({"role": "user", "content": message})
        memory.append({"role": "assistant", "content": response})
        # Cap memory at the most recent 1000 entries
        if len(memory) > 1000:
            memory = memory[-1000:]
        save_memory(memory)
    except Exception as e:
        logger.error(f"Error updating memory: {e}")
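# For reference, chat_memory.json is a flat JSON list of chat turns, e.g.:
#   [
#     {"role": "user", "content": "Hello"},
#     {"role": "assistant", "content": "Hi! How can I help?"}
#   ]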
# ---------------------------
# ML Model Management Functions
# ---------------------------
def load_or_initialize_model():
    """Load the ML model from a file or initialize a new one."""
    try:
        if os.path.exists(model_file):
            model = joblib.load(model_file)
            logger.info("ML model loaded successfully.")
            return model
        # A fresh pipeline is unfitted; predict() raises until it is trained,
        # which respond() relies on to trigger the LLM fallback.
        model = Pipeline([
            ("vectorizer", CountVectorizer()),
            ("classifier", RandomForestClassifier(n_estimators=100, random_state=42))
        ])
        logger.info("Initialized new ML model pipeline.")
        return model
    except Exception as e:
        logger.error(f"Error loading or initializing model: {e}")
        raise
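# Once trained, the pipeline maps raw strings straight to labels, e.g.
# (label names here are illustrative, taken from the hypothetical CSV below):
# model = load_or_initialize_model()
# model.predict(["How do I reset my password?"])  # -> array(['password_reset'])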
def train_model_on_files():
    """Train the ML model on CSV files in the knowledge base."""
    try:
        model = load_or_initialize_model()
        texts, labels = [], []
        files_loaded = 0
        # Load data from the knowledge base
        for file_name in os.listdir(knowledge_base_dir):
            file_path = os.path.join(knowledge_base_dir, file_name)
            if file_path.endswith(".csv"):
                try:
                    df = pd.read_csv(file_path)
                    if "text" in df.columns and "label" in df.columns:
                        texts.extend(df["text"].astype(str).tolist())
                        labels.extend(df["label"].astype(str).tolist())
                        files_loaded += 1
                        logger.info(f"Loaded data from '{file_name}'.")
                    else:
                        logger.warning(f"File '{file_name}' is missing 'text' or 'label' columns.")
                        return f"File '{file_name}' does not contain required 'text' and 'label' columns."
                except Exception as e:
                    logger.error(f"Error reading '{file_name}': {e}")
                    return f"Error reading '{file_name}': {str(e)}"
        if texts and labels:
            try:
                model.fit(texts, labels)
                joblib.dump(model, model_file)
                logger.info("ML model trained and saved successfully.")
                # Report only the CSV files actually used, not everything in the directory
                return f"Model trained on {len(texts)} samples from {files_loaded} files."
            except Exception as e:
                logger.error(f"Error during model training: {e}")
                return f"Error during model training: {str(e)}"
        logger.warning("No valid training data found in the knowledge base.")
        return "No valid training data found in the knowledge base."
    except Exception as e:
        logger.error(f"Unexpected error in training model: {e}")
        return f"Unexpected error: {str(e)}"
# ---------------------------
# Chat Response Function
# ---------------------------
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """
    Generate a response to the user's message using the ML model or the hosted LLM.

    Parameters:
    - message (str): User's input message.
    - history (list): Conversation history.
    - system_message (str): System prompt.
    - max_tokens (int): Maximum number of tokens for the LLM response.
    - temperature (float): Sampling temperature for the LLM.
    - top_p (float): Nucleus sampling parameter for the LLM.

    Returns:
    - response (str): Generated response.
    """
    try:
        # Attempt to get a prediction from the ML model; an untrained
        # pipeline raises here, which triggers the LLM fallback below.
        model = load_or_initialize_model()
        pred_label = model.predict([message])[0]
        response = f"Predicted response: {pred_label}"
        update_memory(message, response)
        logger.info("Response generated using ML model.")
        return response
    except Exception as e:
        logger.info(f"ML model could not generate a response ({e}). Falling back to the hosted LLM.")
    # Generate response using the hosted LLM
    try:
        messages = [{"role": "system", "content": system_message}]
        for turn in history:
            messages.append({"role": turn["role"], "content": turn["content"]})
        messages.append({"role": "user", "content": message})
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Streamed chunks are dataclass objects, not dicts, so read the
            # delta content via attribute access (it may be None on some chunks).
            token = chunk.choices[0].delta.content or ""
            response += token
        update_memory(message, response)
        logger.info("Response generated using the hosted LLM.")
        return response
    except Exception as e:
        logger.error(f"Error generating response with the LLM: {e}")
        response = f"Error generating response: {str(e)}"
        update_memory(message, response)
        return response
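# Quick manual smoke test (assumes network access to the inference endpoint;
# intentionally left commented out so it does not run on import):
# print(respond("Hello!", [], "You are an advanced AI Chatbot.", 256, 0.7, 0.95))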
# ---------------------------
# Gradio Interface
# ---------------------------
def create_gradio_interface():
    """Create and configure the Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown("# 🧠 Advanced AI Chatbot with Knowledge Base and Model Training")

        # Chat Tab
        with gr.Tab("💬 Chat"):
            chatbot = gr.Chatbot(label="AI Chatbot", type="messages")
            with gr.Row():
                with gr.Column(scale=5):
                    user_input = gr.Textbox(
                        label="Your Message",
                        placeholder="Type your message here...",
                        lines=1
                    )
                with gr.Column(scale=1, min_width=100):
                    send_button = gr.Button("Send", variant="primary")
            with gr.Row():
                system_message = gr.Textbox(
                    value="You are an advanced AI Chatbot.",
                    label="System Message",
                    visible=False
                )
                max_tokens = gr.Slider(
                    minimum=100, maximum=2048, value=512, step=100, label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                )
            def handle_message(message, history, system_message, max_tokens, temperature, top_p):
                response = respond(message, history, system_message, max_tokens, temperature, top_p)
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": response})
                # Return the updated history and clear the input box
                return history, ""

            send_button.click(
                handle_message,
                inputs=[user_input, chatbot, system_message, max_tokens, temperature, top_p],
                outputs=[chatbot, user_input],
            )
            user_input.submit(
                handle_message,
                inputs=[user_input, chatbot, system_message, max_tokens, temperature, top_p],
                outputs=[chatbot, user_input],
            )
        # Knowledge Base Tab
        with gr.Tab("📚 Knowledge Base"):
            gr.Markdown("### Manage Knowledge Base")
            file_upload = gr.File(
                label="Upload CSV File",
                file_types=[".csv"],
                file_count="single"  # Allows only a single file upload
            )
            upload_output = gr.Textbox(label="Upload Result", interactive=False)
            train_button = gr.Button("🚀 Train Model on Knowledge Base")
            train_output = gr.Textbox(label="Training Result", interactive=False)
            def upload_file(file):
                if not file:
                    return "No file uploaded."
                try:
                    # Determine file path and name (Gradio may pass a plain
                    # path string or, in some versions, a dict)
                    if isinstance(file, dict):
                        file_path = file.get('path', '')
                        file_name = file.get('name', '')
                    else:
                        file_path = file
                        file_name = os.path.basename(file_path)
                    # Validate file extension
                    if not file_name.endswith(".csv"):
                        logger.warning(f"Invalid file type attempted: {file_name}")
                        return "Invalid file type. Please upload a CSV file."
                    # Save file to knowledge base directory
                    destination_path = os.path.join(knowledge_base_dir, file_name)
                    shutil.copy(file_path, destination_path)
                    logger.info(f"File '{file_name}' uploaded successfully.")
                    return f"File '{file_name}' uploaded successfully."
                except Exception as e:
                    logger.error(f"Error uploading file: {e}")
                    return f"Error uploading file: {str(e)}"
            file_upload.change(upload_file, inputs=file_upload, outputs=upload_output)
            train_button.click(train_model_on_files, inputs=None, outputs=train_output)
        # Memory Tab
        with gr.Tab("🧠 Memory"):
            gr.Markdown("### View and Manage Conversation Memory")
            memory_display = gr.JSON(label="Conversation Memory")
            with gr.Row():
                refresh_memory = gr.Button("🔄 Refresh Memory")
                clear_memory = gr.Button("🗑️ Clear Memory")
                export_memory = gr.Button("📤 Export Memory")
            # Keep the download component visible so the exported file is clickable
            export_output = gr.File(label="Download Memory")
            def display_memory():
                return load_memory()

            def clear_memory_func():
                try:
                    save_memory([])
                    logger.info("Conversation memory cleared.")
                    return []
                except Exception as e:
                    logger.error(f"Error clearing memory: {e}")
                    # Return a JSON-friendly value for the gr.JSON component
                    return {"error": str(e)}

            def export_memory_func():
                if os.path.exists(memory_file):
                    return memory_file  # Gradio will handle the download
                # gr.File expects a path; return None when there is nothing to export
                return None

            refresh_memory.click(display_memory, inputs=None, outputs=memory_display)
            clear_memory.click(clear_memory_func, inputs=None, outputs=memory_display)
            export_memory.click(export_memory_func, inputs=None, outputs=export_output)
        # Download Model Tab
        with gr.Tab("💾 Download Model"):
            gr.Markdown("### Download the Trained Model")
            download_button = gr.Button("📥 Download Model")
            model_download_output = gr.File(label="Downloadable Model")

            def download_model():
                if os.path.exists(model_file):
                    return model_file  # Gradio will handle the file download
                # gr.File expects a path; return None when no model has been trained
                return None

            download_button.click(download_model, inputs=None, outputs=model_download_output)
        # Settings Tab
        with gr.Tab("⚙️ Settings"):
            gr.Markdown("### Application Settings")
            gr.Textbox(
                value="",
                label="Settings Placeholder",
                placeholder="Add settings here..."
            )

    return demo
# ---------------------------
# Main Execution
# ---------------------------
if __name__ == "__main__":
    try:
        interface = create_gradio_interface()
        logger.info("Launching Gradio interface.")
        interface.launch()
    except Exception as e:
        logger.critical(f"Application failed to start: {e}")
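# Note: launch() accepts the usual Gradio options, e.g. launch(share=True) for a
# temporary public link or launch(server_port=7860) to pin the port when running
# outside of Spaces.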