Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| from bertopic import BERTopic | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from datetime import datetime | |
| import json | |
| from collections import deque | |
| from datasets import load_dataset | |
class BERTopicChatbot:
    """Retrieval-style chatbot over a Hugging Face dataset.

    The text column of the dataset is embedded with a sentence-transformer
    model and clustered with BERTopic. A query is answered by returning the
    most similar document (cosine similarity on the embeddings); the topic
    BERTopic assigns to the query is recorded for the metrics views.
    """

    # Reply used when no document is similar enough to the query.
    # (Vietnamese: "Sorry, I do not have enough information to answer this
    # question accurately.")
    FALLBACK_RESPONSE = "Xin lỗi, tôi không có đủ thông tin để trả lời câu hỏi này một cách chính xác."

    def __init__(self, dataset_name, text_column, split="train", max_samples=10000,
                 embedding_model_name='all-MiniLM-L6-v2'):
        """Load the dataset, embed its documents and fit the topic model.

        Args:
            dataset_name: Name of the dataset on Hugging Face
                (e.g. 'vietnam/legal').
            text_column: Name of the column containing the text data.
            split: Which split of the dataset to use
                ('train', 'test', 'validation').
            max_samples: Maximum number of samples to use (to manage memory).
                Larger datasets are shuffled with a fixed seed and truncated.
            embedding_model_name: Sentence-transformers model used for both
                BERTopic and the similarity search.

        Raises:
            ValueError: If ``text_column`` is missing, or if the column holds
                no usable (non-null, non-blank) text.
            Exception: Anything raised by dataset loading or model fitting is
                propagated unchanged; reporting it is left to the caller so
                the error is not displayed twice in the UI.
        """
        self.sentence_model = SentenceTransformer(embedding_model_name)

        dataset = load_dataset(dataset_name, split=split)
        if len(dataset) > max_samples:
            # Fixed seed keeps the sampled subset reproducible across reloads.
            dataset = dataset.shuffle(seed=42).select(range(max_samples))
        self.df = dataset.to_pandas()

        if text_column not in self.df.columns:
            raise ValueError(
                f"Column '{text_column}' not found in dataset. "
                f"Available columns: {list(self.df.columns)}"
            )

        # Null / blank rows are dropped, so ``self.documents`` is not
        # guaranteed to be row-aligned with ``self.df``.
        self.documents = self._clean_documents(self.df[text_column].tolist())
        if not self.documents:
            raise ValueError(
                f"Column '{text_column}' contains no usable text documents."
            )

        # Embed once and reuse the vectors for both topic modelling and
        # similarity search; otherwise fit_transform would embed every
        # document a second time internally.
        self.doc_embeddings = self.sentence_model.encode(self.documents)
        self.topic_model = BERTopic(embedding_model=self.sentence_model)
        self.topics, self.probs = self.topic_model.fit_transform(
            self.documents, embeddings=self.doc_embeddings
        )

        # Rolling per-query metrics (last 100 queries) plus topic hit counts.
        self.metrics_history = {
            'similarities': deque(maxlen=100),
            'response_times': deque(maxlen=100),
            'token_counts': deque(maxlen=100),
            'topics_accessed': {}
        }

        self.dataset_info = {
            'name': dataset_name,
            'split': split,
            'total_documents': len(self.documents),
            # BERTopic labels outliers with topic -1; that bucket is not a
            # discovered topic, so it is excluded from the count.
            'topics_found': len(set(self.topics) - {-1})
        }

    @staticmethod
    def _clean_documents(values):
        """Return the usable texts in ``values`` as a list of strings.

        ``None`` and NaN entries are dropped, remaining values are converted
        with ``str()``, and strings that are empty or whitespace-only are
        dropped as well.
        """
        cleaned = []
        for value in values:
            if value is None:
                continue
            if isinstance(value, float) and value != value:  # NaN check
                continue
            text = str(value)
            if text.strip():
                cleaned.append(text)
        return cleaned

    @staticmethod
    def _trend_figure(values, trace_name, title, yaxis_title):
        """Build a line+marker figure of ``values`` against query number."""
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            y=list(values),
            mode='lines+markers',
            name=trace_name
        ))
        fig.update_layout(
            title=title,
            yaxis_title=yaxis_title,
            xaxis_title='Query Number'
        )
        return fig

    def get_metrics_visualizations(self):
        """Generate visualizations for chatbot metrics.

        Returns:
            Tuple ``(fig_similarity, fig_response_time, fig_tokens,
            fig_topics)`` of Plotly figures built from ``metrics_history``.
        """
        history = self.metrics_history
        fig_similarity = self._trend_figure(
            history['similarities'], 'Similarity Score',
            'Response Similarity Trend', 'Similarity Score'
        )
        fig_response_time = self._trend_figure(
            history['response_times'], 'Response Time',
            'Response Time Trend', 'Time (seconds)'
        )
        fig_tokens = self._trend_figure(
            history['token_counts'], 'Token Count',
            'Token Usage Trend', 'Number of Tokens'
        )

        # Pie chart of how often each topic id was hit by a query.
        topic_counts = history['topics_accessed']
        fig_topics = go.Figure(data=[go.Pie(
            labels=list(topic_counts.keys()),
            values=list(topic_counts.values())
        )])
        fig_topics.update_layout(title='Topics Accessed Distribution')

        # Shared compact layout for all four figures.
        for fig in (fig_similarity, fig_response_time, fig_tokens, fig_topics):
            fig.update_layout(
                autosize=True,
                margin=dict(l=20, r=20, t=40, b=20),
                height=300
            )
        return fig_similarity, fig_response_time, fig_tokens, fig_topics

    def get_most_similar_document(self, query, top_k=3):
        """Return the ``top_k`` documents most similar to ``query``.

        Args:
            query: Text to search for.
            top_k: Number of documents to return. Values larger than the
                number of documents return all of them; values <= 0 return
                an empty result.

        Returns:
            Tuple ``(documents, similarities)``, both ordered from most to
            least similar. ``similarities`` is a NumPy array of cosine scores.
        """
        if top_k <= 0:
            # Without this guard ``[-0:]`` would select every document.
            return [], np.empty(0, dtype=float)

        query_embedding = self.sentence_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.doc_embeddings)[0]
        # argsort is ascending: take the last k indices and reverse them.
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.documents[i] for i in top_indices], similarities[top_indices]

    def get_response(self, user_query, min_similarity=0.5):
        """Answer ``user_query`` with the closest document and record metrics.

        Args:
            user_query: The user's message.
            min_similarity: Minimum cosine similarity the best match must
                reach; below it ``FALLBACK_RESPONSE`` is returned instead.

        Returns:
            Tuple ``(response, metrics)``. ``metrics`` has the keys
            'similarity', 'response_time' (seconds), 'tokens'
            (whitespace-separated word count of the response) and 'topic'.
            If processing fails, the response is an error message and
            ``metrics`` is ``{'error': <message>}``.
        """
        try:
            start_time = datetime.now()
            similar_docs, similarities = self.get_most_similar_document(user_query)

            # Topic assigned to the query, stored as a string id.
            query_topic, _ = self.topic_model.transform([user_query])
            topic_id = str(query_topic[0])
            accessed = self.metrics_history['topics_accessed']
            accessed[topic_id] = accessed.get(topic_id, 0) + 1

            # Results are sorted best-first, so index 0 holds the maximum.
            best_similarity = float(similarities[0]) if len(similarities) else 0.0
            if best_similarity < min_similarity:
                response = self.FALLBACK_RESPONSE
            else:
                response = similar_docs[0]

            elapsed = (datetime.now() - start_time).total_seconds()
            token_count = len(response.split())

            self.metrics_history['similarities'].append(best_similarity)
            self.metrics_history['response_times'].append(elapsed)
            self.metrics_history['token_counts'].append(token_count)

            metrics = {
                'similarity': best_similarity,
                'response_time': elapsed,
                'tokens': token_count,
                'topic': topic_id
            }
            return response, metrics
        except Exception as e:
            # Best effort: the chat UI shows the error text as the reply
            # instead of crashing the app.
            return f"Error processing query: {str(e)}", {'error': str(e)}

    def get_dataset_info(self):
        """Return information about the loaded dataset and aggregate metrics.

        Returns:
            Dict with 'dataset_info' and 'metrics' (average similarity,
            average response time, total tokens, topic hit counts). Averages
            are 0 while no query has been recorded. On failure the dict
            holds an 'error' message and ``None`` for the other two keys.
        """
        try:
            similarities = self.metrics_history['similarities']
            response_times = self.metrics_history['response_times']
            return {
                'dataset_info': self.dataset_info,
                'metrics': {
                    # Plain floats (not np.float64) keep the result
                    # JSON-serialisable.
                    'avg_similarity': float(np.mean(list(similarities))) if similarities else 0,
                    'avg_response_time': float(np.mean(list(response_times))) if response_times else 0,
                    'total_tokens': sum(self.metrics_history['token_counts']),
                    'topics_accessed': self.metrics_history['topics_accessed']
                }
            }
        except Exception as e:
            return {
                'error': str(e),
                'dataset_info': None,
                'metrics': None
            }
def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
    """Create a BERTopicChatbot for the given dataset settings and return it.

    The arguments are forwarded positionally, unchanged, to the
    BERTopicChatbot constructor; any exception it raises propagates.
    """
    chatbot = BERTopicChatbot(dataset_name, text_column, split, max_samples)
    return chatbot
def main():
    """Render the Streamlit page: dataset sidebar, chat tab and metrics tab."""
    st.title("🤖 Trợ Lý AI - BERTopic")
    st.caption("Trò chuyện với chúng mình nhé!")

    state = st.session_state

    # --- Sidebar: choose and load a dataset --------------------------------
    with st.sidebar:
        st.header("Dataset Configuration")
        dataset_name = st.text_input(
            "Hugging Face Dataset Name",
            value="Kanakmi/mental-disorders",
            help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
        )
        text_column = st.text_input(
            "Text Column Name",
            value="text",
            help="Enter the name of the column containing the text data"
        )
        split = st.selectbox(
            "Dataset Split",
            options=["train", "test", "validation"],
            index=0
        )
        max_samples = st.number_input(
            "Maximum Samples",
            min_value=100,
            max_value=100000,
            value=10000,
            step=1000,
            help="Maximum number of samples to load from the dataset"
        )
        load_requested = st.button("Load Dataset")
        if load_requested:
            with st.spinner("Loading dataset and initializing model..."):
                try:
                    new_bot = initialize_chatbot(
                        dataset_name, text_column, split, max_samples
                    )
                    state.chatbot = new_bot
                    st.success("Dataset loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading dataset: {str(e)}")

    # --- Session defaults (first run, or nothing loaded yet) ----------------
    if 'chatbot' not in state:
        state.chatbot = None
    if 'messages' not in state:
        state.messages = []

    chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])

    # --- Chat tab -----------------------------------------------------------
    with chat_tab:
        # Replay the stored conversation on every rerun.
        for entry in state.messages:
            with st.chat_message(entry["role"]):
                st.markdown(entry["content"])

        # The input box is only offered once a chatbot has been loaded.
        if state.chatbot is None:
            st.info("Please load a dataset first to start chatting.")
        else:
            prompt = st.chat_input("Hãy nói gì đó...")
            if prompt:
                state.messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)

                response, metrics = state.chatbot.get_response(prompt)

                with st.chat_message("assistant"):
                    st.markdown(response)
                    with st.expander("Response Metrics"):
                        st.json(metrics)
                state.messages.append({"role": "assistant", "content": response})

    # --- Metrics tab --------------------------------------------------------
    with metrics_tab:
        if state.chatbot is None:
            st.info("Please load a dataset first to view metrics.")
        else:
            try:
                figures = state.chatbot.get_metrics_visualizations()
                fig_similarity, fig_response_time, fig_tokens, fig_topics = figures

                left_col, right_col = st.columns(2)
                with left_col:
                    st.plotly_chart(fig_similarity, use_container_width=True)
                    st.plotly_chart(fig_tokens, use_container_width=True)
                with right_col:
                    st.plotly_chart(fig_response_time, use_container_width=True)
                    st.plotly_chart(fig_topics, use_container_width=True)

                st.subheader("Overall Statistics")
                history = state.chatbot.metrics_history
                if history['similarities']:
                    avg_similarity = np.mean(list(history['similarities']))
                    avg_response_time = np.mean(list(history['response_times']))
                    total_tokens = sum(history['token_counts'])

                    sim_col, time_col, token_col = st.columns(3)
                    with sim_col:
                        st.metric("Avg Similarity", f"{avg_similarity:.3f}")
                    with time_col:
                        st.metric("Avg Response Time", f"{avg_response_time:.3f}s")
                    with token_col:
                        st.metric("Total Tokens Used", total_tokens)
                else:
                    st.info("No chat history available yet. Start a conversation to see metrics.")
            except Exception as e:
                st.error(f"Error displaying metrics: {str(e)}")
# Run the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()