Spaces:
Runtime error
| import pandas as pd | |
| import numpy as np | |
| import streamlit as st | |
| import tensorflow as tf | |
| from tensorflow.keras.preprocessing.text import tokenizer_from_json | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| import json | |
# Input file locations. These are relative paths, so they are resolved against
# the process's current working directory.
FULL_MODEL_PATH: str = 'full_model.h5'  # complete saved Keras model
FINETUNE_WEIGHTS_PATH: str = 'fine_tuned_model_weights.h5'  # weights loaded on top of the full model
TOKENIZER_PATH: str = 'tokenizer.json'  # serialized Keras tokenizer config
DATA_PATH: str = 'customer_agent (1).csv'  # CSV table holding the agent responses
def load_model_and_weights():
    """Load the Keras model, apply the fine-tuned weights, and load the tokenizer.

    The model architecture comes from FULL_MODEL_PATH; its weights are then
    replaced by those in FINETUNE_WEIGHTS_PATH. The tokenizer is rebuilt from
    the JSON in TOKENIZER_PATH.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if any step
        fails. Failures are reported to the user with ``st.error``.
    """
    try:
        model = tf.keras.models.load_model(FULL_MODEL_PATH)
        # NOTE(review): assumes these weights match the architecture stored
        # in FULL_MODEL_PATH.
        model.load_weights(FINETUNE_WEIGHTS_PATH)

        with open(TOKENIZER_PATH, 'r', encoding='utf-8') as f:
            tokenizer_data = json.load(f)
        # tokenizer_from_json() json-decodes its argument, so it needs a JSON
        # *string*. How the file was written decides what json.load returns:
        # a dict (raw JSON object) or a string (JSON-encoded string).
        # Normalize to a string so both layouts work.
        if not isinstance(tokenizer_data, str):
            tokenizer_data = json.dumps(tokenizer_data)
        tokenizer = tokenizer_from_json(tokenizer_data)
    except Exception as e:
        # App-level boundary: show the problem in the UI instead of crashing.
        st.error(f"Error loading model or tokenizer: {e}")
        return None, None
    return model, tokenizer
# Initialize model and tokenizer; both are None if loading failed.
model, tokenizer = load_model_and_weights()

# Load the table of agent responses that get_response indexes into.
# NOTE(review): Streamlit re-runs this module on every interaction, so the
# model and CSV are reloaded each time. Consider st.cache_resource /
# st.cache_data if the installed Streamlit version provides them.
try:
    data = pd.read_csv(DATA_PATH)
except (FileNotFoundError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
    # Match load_model_and_weights: report the error in the UI rather than
    # crashing the app at import time. An empty frame has length 0, so
    # get_response falls back to its "couldn't find a suitable response" reply.
    st.error(f"Error loading dataset: {e}")
    data = pd.DataFrame(columns=['Agent_Response'])
def preprocess_text(text, *, maxlen=100):
    """Tokenize *text* and pad it into a single-item batch for the model.

    Uses the module-level ``tokenizer`` and Keras ``pad_sequences``.

    Args:
        text: Raw user message.
        maxlen: Sequence length passed to ``pad_sequences``. Defaults to 100,
            the value previously hard-coded here.
            NOTE(review): presumably this must equal the input length the
            model was trained with. Confirm against the training code.

    Returns:
        The padded sequences that ``pad_sequences`` produces for ``[text]``.
    """
    sequences = tokenizer.texts_to_sequences([text])
    return pad_sequences(sequences, maxlen=maxlen)
def get_response(query):
    """Return the stored agent response that the model selects for *query*.

    The index of the model's highest output score is used as a row position
    in the module-level ``data`` table. That row's ``Agent_Response`` value
    is returned.

    Args:
        query: The user's message.

    Returns:
        The selected response, or a fallback message if the model or
        tokenizer failed to load, or the predicted index has no matching row.
    """
    if model is None or tokenizer is None:
        return "Model or tokenizer not loaded properly."

    features = preprocess_text(query)
    # verbose=0 stops Keras from printing a progress bar for every message.
    prediction = model.predict(features, verbose=0)
    # preprocess_text builds a batch of one. Assuming the model emits one
    # score per response class, the flat argmax is the class index.
    response_index = int(np.argmax(prediction))

    # The bound check is positional, so the lookup must be positional too;
    # label-based .loc only agrees when the frame has a default RangeIndex.
    if response_index < len(data):
        return data['Agent_Response'].iloc[response_index]  # Adjust column name as necessary
    return "Sorry, I couldn't find a suitable response."
def main():
    """Render the chatbot UI and handle a single Send action.

    The conversation is stored in ``st.session_state.chat_history`` as a list
    of ``{"role": ..., "content": ...}`` dicts, so it persists across
    Streamlit reruns. The whole history is redrawn each time.
    """
    st.title("Customer Service Chatbot")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # NOTE(review): the chosen department is never passed to get_response,
    # so this widget is currently display-only and does not affect replies.
    st.selectbox(
        "Select Department",
        ["customer_service", "billing", "orders", "technical", "hardware", "software"],
    )
    user_input = st.text_input("Enter your message:")

    # Skip empty or whitespace-only messages. They would likely tokenize to an
    # all-padding sequence and still produce an arbitrary prediction.
    if st.button("Send") and user_input.strip():
        response = get_response(user_input)
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        st.session_state.chat_history.append({"role": "assistant", "content": response})

    st.write("Chat History:")
    for message in st.session_state.chat_history:
        st.write(f"{message['role'].capitalize()}: {message['content']}")


if __name__ == "__main__":
    main()