Spaces:
Runtime error
Runtime error
| # Import libraries | |
| import pandas as pd | |
| from fastapi import FastAPI | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import gradio as gr | |
# Pull the e-commerce customer-support Q&A dataset from the Hugging Face Hub.
support_data = load_dataset("rjac/e-commerce-customer-support-qa")

# Read the FAQ table (path is relative to the current working directory),
# map its prompt/response columns onto the shared Question/Answer schema and
# keep only those two columns.
faq_data = pd.read_csv("Ecommerce_FAQs.csv").rename(
    columns={'prompt': 'Question', 'response': 'Answer'}
)[['Question', 'Answer']]

# Materialise the training split as a DataFrame for row-wise processing below.
support_data_df = pd.DataFrame(support_data['train'])
# Extract question-answer pairs from the conversation field
def extract_conversation(data):
    """Split a raw support conversation into a question/answer pair.

    The text is split on blank lines ("\\n\\n"). The second segment is taken
    as the question and the third as the answer. In each segment, everything
    after the first ": " (the speaker label) is kept.

    Args:
        data: Raw conversation text. Non-string values, such as NaN or None
            for a missing cell, are tolerated.

    Returns:
        pd.Series with "Question" and "Answer" entries. Missing segments give
        empty strings. If the input is not a string, or a present segment has
        no ": " separator, both fields are empty.
    """
    empty = pd.Series({"Question": "", "Answer": ""})
    # Missing cells arrive as float NaN / None. Calling .split on them would
    # raise AttributeError, which the IndexError handler below does not catch.
    if not isinstance(data, str):
        return empty
    parts = data.split("\n\n")
    try:
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
    except IndexError:
        # A segment without a "speaker: " prefix makes the whole row unusable.
        return empty
    return pd.Series({"Question": question, "Answer": answer})
# Apply extraction function: one Question/Answer pair per conversation row.
support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
# Combine FAQ data with support data into a single knowledge base.
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)
# Blank CSV cells are read as NaN (a float). Normalise both columns to text so
# the sentence encoder and the answer lookup only ever receive strings.
combined_data['Question'] = combined_data['Question'].fillna('').astype(str)
combined_data['Answer'] = combined_data['Answer'].fillna('').astype(str)
# Drop rows whose question is blank (e.g. failed extractions). They carry no
# retrieval signal but would still be embedded and could be returned as the
# nearest match. Reset the index so row positions stay contiguous.
combined_data = combined_data[combined_data['Question'].str.strip() != ''].reset_index(drop=True)
# Sentence encoder shared by index construction and query-time retrieval.
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Embed every knowledge-base question once, at start-up.
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_tensor=True)

# Flat (exhaustive) L2 index over the question embeddings. Vectors are added
# in combined_data order, so a hit's id is that row's position.
embedding_dim = embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(embeddings.cpu().numpy())
# Fine-tuned DialoGPT checkpoint used for generative fallback answers.
# Change this Hub id (or point it at a local path) to swap in another model.
DIALOGPT_MODEL_ID = "Mishal23/fine_tuned_dialoGPT_model"
tokenizer_gpt = AutoTokenizer.from_pretrained(DIALOGPT_MODEL_ID)
model_gpt = AutoModelForCausalLM.from_pretrained(DIALOGPT_MODEL_ID)
# Define Retrieval Function
def retrieve_answer(question):
    """Return the stored answer whose question is nearest to *question*.

    The query is embedded with the module-level SentenceTransformer ``model``
    and searched in the FAISS ``index`` with k=1. The hit's row position
    selects the answer from ``combined_data``.

    Falls back to ``generate_response`` (DialoGPT) when the index reports no
    neighbour, or when the matched answer is missing or blank.

    Args:
        question: The user's question text.

    Returns:
        The answer string.
    """
    query_vector = model.encode([question], convert_to_tensor=True).cpu().numpy()
    _, neighbour_ids = index.search(query_vector, k=1)
    best_match_idx = int(neighbour_ids[0][0])
    # FAISS reports a missing neighbour as -1 (e.g. empty index). Treat that as
    # "no match" rather than passing a negative position to .iloc.
    if best_match_idx < 0:
        return generate_response(question)
    answer = combined_data.iloc[best_match_idx]['Answer']
    # Blank CSV cells come back as NaN (a float), so check the type before .strip().
    if not isinstance(answer, str) or not answer.strip():
        return generate_response(question)
    return answer
# Generate response using your fine-tuned DialoGPT model
def generate_response(user_input, max_length=100):
    """Generate a free-form reply to *user_input* with the fine-tuned DialoGPT model.

    Args:
        user_input: The user's message.
        max_length: Upper bound on the total sequence length (prompt tokens
            plus generated tokens) passed to ``generate``. Defaults to 100.

    Returns:
        The decoded reply, or a canned apology when the model produces no text.
    """
    # The DialoGPT usage example ends each user turn with the EOS token, which
    # marks where the bot's reply should begin.
    prompt = user_input + (tokenizer_gpt.eos_token or "")
    input_ids = tokenizer_gpt.encode(prompt, return_tensors='pt')
    # Single, unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    output_ids = model_gpt.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        pad_token_id=tokenizer_gpt.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer_gpt.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    return response if response.strip() else "Oops, I don't know the answer to that."
# Initialize FastAPI app
app = FastAPI()
# Add CORS middleware.
# Any origin may call the API, but credentialed requests (cookies / auth
# headers) are not allowed: CORS forbids combining credentials with a "*"
# origin, and none of the code here relies on cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Landing page at "/" that points visitors to the Gradio chat UI.
# NOTE(review): this script only calls iface.launch(); the FastAPI app is not
# served unless it is run separately (e.g. with uvicorn) or Gradio is mounted on it.
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Return a static HTML welcome page for the chatbot."""
    return HTMLResponse("""<html>
<head>
<title>E-commerce Support Chatbot</title>
</head>
<body>
<h1>Welcome to the E-commerce Support Chatbot</h1>
<p>Use the Gradio interface to chat with the bot!</p>
</body>
</html>""")
# Gradio Chat Interface for E-commerce Support Chatbot
def chatbot_interface(user_input, chat_history=None):
    """Answer one user message and render the whole conversation as Markdown.

    Args:
        user_input: The message typed by the user.
        chat_history: Prior turns as ``(sender, message)`` tuples, where sender
            is "User" or "Bot". ``None`` (the default) starts a new
            conversation. The passed-in list is not modified.

    Returns:
        A ``(markdown, history)`` tuple: the formatted transcript and the
        updated history list to store back in the Gradio state.
    """
    # Work on a copy. The previous ``chat_history=[]`` default was a single
    # list shared by every call that omitted the argument, so turns leaked
    # between unrelated conversations.
    history = list(chat_history) if chat_history else []
    # Retrieve response from the knowledge base or generate it
    response = retrieve_answer(user_input)
    history.append(("User", user_input))
    history.append(("Bot", response))
    # Format chat history for display
    chat_display = [
        f"**You**: {message}" if sender == "User" else f"**Bot**: {message}"
        for sender, message in history
    ]
    return "\n\n".join(chat_display), history
# Set up Gradio Chat Interface with conversational format
iface = gr.Interface(
    fn=chatbot_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type your question here..."),
        gr.State([]),  # State variable to maintain chat history
    ],
    outputs=[
        gr.Markdown(),  # Display formatted chat history
        gr.State(),  # Update state
    ],
    title="E-commerce Support Chatbot",
    description="Ask questions about order tracking, returns, account help, and more!",
)

# Launch only when run as a script (e.g. `python app.py`). Importing this
# module (tests, reload tooling, mounting into another server) then does not
# start a blocking web server as a side effect.
if __name__ == "__main__":
    iface.launch()