Narayana02's picture
Update app.py
09bc1b5 verified
import pandas as pd
import numpy as np
import streamlit as st
import tensorflow as tf
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
import json
# Paths to the artifacts this app reads. They are relative paths, so they are
# resolved against the current working directory when the app starts.
FULL_MODEL_PATH: str = 'full_model.h5'                       # complete saved Keras model
FINETUNE_WEIGHTS_PATH: str = 'fine_tuned_model_weights.h5'   # weights loaded on top of that model
TOKENIZER_PATH: str = 'tokenizer.json'                       # serialized Keras tokenizer
DATA_PATH: str = 'customer_agent (1).csv'                    # CSV holding the candidate agent responses
# Load the full model, apply the fine-tuned weights, and load the tokenizer.
def load_model_and_weights():
    """Load the Keras model, apply the fine-tuned weights, and rebuild the tokenizer.

    The model is read from ``FULL_MODEL_PATH`` and then has the weights stored in
    ``FINETUNE_WEIGHTS_PATH`` loaded into it. The tokenizer is rebuilt from the
    JSON in ``TOKENIZER_PATH``.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if any step fails.
        The failure is reported in the page via ``st.error``; it is not raised.
    """
    try:
        model = tf.keras.models.load_model(FULL_MODEL_PATH)
        # Replace the saved weights with the fine-tuned ones.
        model.load_weights(FINETUNE_WEIGHTS_PATH)

        with open(TOKENIZER_PATH, 'r', encoding='utf-8') as f:
            raw_text = f.read()
        parsed = json.loads(raw_text)
        # tokenizer_from_json() expects a JSON *string*. A file written with
        # json.dump(tokenizer.to_json(), f) decodes to that string. A file
        # written with f.write(tokenizer.to_json()) decodes to a dict, and in
        # that case the raw file text is already the string it needs.
        tokenizer_json = parsed if isinstance(parsed, str) else raw_text
        tokenizer = tokenizer_from_json(tokenizer_json)
        return model, tokenizer
    except Exception as e:
        # UI boundary: show the problem to the user instead of crashing the app.
        st.error(f"Error loading model or tokenizer: {e}")
        return None, None
# Initialize model and tokenizer (both are None if loading failed).
model, tokenizer = load_model_and_weights()

# Load the dataset whose 'Agent_Response' column holds the candidate replies.
# A missing or unreadable file is reported in the page, matching how
# load_model_and_weights() reports errors, instead of crashing the script.
# The empty fallback frame has no rows, so any predicted index is out of range
# and get_response() returns its "couldn't find a suitable response" message.
try:
    data = pd.read_csv(DATA_PATH)
except (OSError, ValueError) as e:  # missing file / parse or decode errors
    st.error(f"Error loading dataset: {e}")
    data = pd.DataFrame(columns=['Agent_Response'])
def preprocess_text(text: str, maxlen: int = 100):
    """Convert a raw query into the padded integer sequence fed to the model.

    Args:
        text: The user's message.
        maxlen: Length that every sequence is padded or truncated to. Defaults to
            100, the value this app has always used. It presumably has to match
            the input length the model was trained with, so confirm before
            changing it.

    Returns:
        The array produced by ``pad_sequences`` for a batch of one text,
        i.e. shape ``(1, maxlen)``.
    """
    sequences = tokenizer.texts_to_sequences([text])
    return pad_sequences(sequences, maxlen=maxlen)
def get_response(query):
    """Return the dataset response that the model ranks highest for ``query``.

    The model's prediction is reduced with ``np.argmax`` to a row position in
    ``data``, and that row's ``'Agent_Response'`` value is returned.

    Args:
        query: The user's message.

    Returns:
        The matching response, or a fallback message in two cases: the
        model/tokenizer failed to load, or the predicted row is out of range or
        empty.
    """
    # A failed load yields (None, None). Check for that explicitly instead of
    # relying on the truthiness of Keras model/tokenizer objects.
    if model is None or tokenizer is None:
        return "Model or tokenizer not loaded properly."

    prediction = model.predict(preprocess_text(query))
    # Assumes a single-row batch, so the flattened argmax is the class index; confirm.
    response_index = int(np.argmax(prediction))

    # The bounds check is positional (len), so look the row up positionally too.
    # .loc would use index labels, which match positions only for a default
    # RangeIndex.
    if response_index < len(data):
        response = data['Agent_Response'].iloc[response_index]  # Adjust column name as necessary
        # An empty CSV cell is read as NaN; don't show "nan" to the user.
        if not pd.isna(response):
            return response
    return "Sorry, I couldn't find a suitable response."
def main():
    """Render the chatbot page and handle one click of the Send button.

    The conversation is stored in ``st.session_state.chat_history`` as a list of
    ``{"role": ..., "content": ...}`` dicts, so it persists across reruns. The
    full history is redrawn on every run.
    """
    st.title("Customer Service Chatbot")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # NOTE(review): the selected department is not used anywhere yet. The widget
    # is still rendered so the page layout is unchanged.
    st.selectbox(
        "Select Department",
        ["customer_service", "billing", "orders", "technical", "hardware", "software"],
    )
    user_input = st.text_input("Enter your message:")

    if st.button("Send"):
        # Ignore empty or whitespace-only messages instead of sending them to the
        # model and recording them in the history.
        query = user_input.strip()
        if query:
            response = get_response(query)
            st.session_state.chat_history.append({"role": "user", "content": query})
            st.session_state.chat_history.append({"role": "assistant", "content": response})

    st.write("Chat History:")
    for entry in st.session_state.chat_history:
        st.write(f"{entry['role'].capitalize()}: {entry['content']}")
# Entry point: render the app only when this file is executed as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()