JustKiddo committed on
Commit
14235ee
·
verified ·
1 Parent(s): b7ffac5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +314 -0
app.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ from bertopic import BERTopic
4
+ from sentence_transformers import SentenceTransformer
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ import pandas as pd
8
+ import plotly.graph_objects as go
9
+ from datetime import datetime
10
+ import json
11
+ from collections import deque
12
+ from datasets import load_dataset
13
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
14
+ import torch # Import torch
15
+
16
class BERTopicChatbot:
    """Chatbot that fits a BERTopic model on a dataset and answers with BARTpho.

    On construction the class loads a Hugging Face dataset, fits a BERTopic
    model on one of its text columns, keeps sentence embeddings of every
    document for similarity search, and loads a fine-tuned seq2seq model from
    ``./bartpho_chatbot`` for response generation. Per-query metrics are kept
    in ``metrics_history`` (the three trend series are bounded to the last
    100 entries).
    """

    def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
        """Load the dataset, fit the topic model and load the generator.

        Args:
            dataset_name: Name passed to ``datasets.load_dataset``.
            text_column: Column of the dataset that holds the documents.
            split: Dataset split to load.
            max_samples: Upper bound on the number of rows; larger datasets
                are shuffled with a fixed seed and truncated to this size.

        Raises:
            ValueError: If ``text_column`` is missing or holds no documents.
            Exception: Any error raised while loading the dataset or fitting
                the topic model is reported via ``st.error`` and re-raised.
        """
        # Sentence encoder shared by BERTopic and the similarity search.
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

        # The tokenizer comes from the base BARTpho checkpoint; the weights
        # are loaded further down from the local fine-tuned directory.
        self.bartpho_model_name = "vinai/bartpho-syllable"
        self.tokenizer = AutoTokenizer.from_pretrained(self.bartpho_model_name)

        try:
            dataset = load_dataset(dataset_name, split=split)
            # Sample deterministically when the dataset is larger than allowed.
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))

            self.df = dataset.to_pandas()

            if text_column not in self.df.columns:
                raise ValueError(f"Column '{text_column}' not found in dataset. Available columns: {self.df.columns}")

            # BERTopic and the sentence encoder only accept strings, so rows
            # without text are dropped and the remaining values are cast.
            self.df = self.df.dropna(subset=[text_column]).reset_index(drop=True)
            self.documents = self.df[text_column].astype(str).tolist()
            if not self.documents:
                raise ValueError(f"Column '{text_column}' contains no text documents.")

            # Encode once and reuse the embeddings for both topic modelling
            # and similarity search instead of encoding the corpus twice.
            self.doc_embeddings = self.sentence_model.encode(self.documents)

            self.topic_model = BERTopic(embedding_model=self.sentence_model)
            self.topics, self.probs = self.topic_model.fit_transform(
                self.documents, embeddings=self.doc_embeddings
            )

            # Rolling metrics; the deques keep only the latest 100 queries.
            self.metrics_history = {
                'similarities': deque(maxlen=100),
                'response_times': deque(maxlen=100),
                'token_counts': deque(maxlen=100),
                'topics_accessed': {}
            }

            self.dataset_info = {
                'name': dataset_name,
                'split': split,
                'total_documents': len(self.documents),
                # BERTopic labels outlier documents with -1, which is not a
                # real topic and must not be counted as one.
                'topics_found': len(set(self.topics) - {-1})
            }
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            raise

        # Load the fine-tuned BARTpho model on GPU when one is available.
        self.bartpho_model = AutoModelForSeq2SeqLM.from_pretrained("./bartpho_chatbot").to("cuda" if torch.cuda.is_available() else "cpu")
        self.bartpho_model.eval()

    @staticmethod
    def _trend_figure(values, name, title, yaxis_title):
        """Build a lines+markers figure of *values* against query number."""
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            y=list(values),
            mode='lines+markers',
            name=name
        ))
        fig.update_layout(
            title=title,
            yaxis_title=yaxis_title,
            xaxis_title='Query Number'
        )
        return fig

    def get_metrics_visualizations(self):
        """Generate visualizations for chatbot metrics.

        Returns:
            A tuple ``(fig_similarity, fig_response_time, fig_tokens,
            fig_topics)`` of Plotly figures built from ``metrics_history``.
        """
        history = self.metrics_history

        fig_similarity = self._trend_figure(
            history['similarities'],
            'Similarity Score', 'Response Similarity Trend', 'Similarity Score'
        )
        fig_response_time = self._trend_figure(
            history['response_times'],
            'Response Time', 'Response Time Trend', 'Time (seconds)'
        )
        fig_tokens = self._trend_figure(
            history['token_counts'],
            'Token Count', 'Token Usage Trend', 'Number of Tokens'
        )

        # Topics accessed pie chart
        labels = list(history['topics_accessed'].keys())
        values = list(history['topics_accessed'].values())
        fig_topics = go.Figure(data=[go.Pie(labels=labels, values=values)])
        fig_topics.update_layout(title='Topics Accessed Distribution')

        # Make all figures responsive
        for fig in (fig_similarity, fig_response_time, fig_tokens, fig_topics):
            fig.update_layout(
                autosize=True,
                margin=dict(l=20, r=20, t=40, b=20),
                height=300
            )

        return fig_similarity, fig_response_time, fig_tokens, fig_topics

    def get_most_similar_document(self, query, top_k=3):
        """Return the ``top_k`` documents most similar to *query*.

        Args:
            query: Text to compare against the stored document embeddings.
            top_k: Number of documents to return; values <= 0 yield nothing.

        Returns:
            A tuple ``(documents, similarities)`` ordered from most to least
            similar, where ``similarities`` is a NumPy array of cosine scores.
        """
        # A slice of [-0:] would select every document, so a non-positive
        # top_k is handled explicitly.
        if top_k <= 0:
            return [], np.array([])

        query_embedding = self.sentence_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.doc_embeddings)[0]

        # argsort is ascending: take the last top_k and reverse them.
        top_indices = similarities.argsort()[-top_k:][::-1]

        return [self.documents[i] for i in top_indices], similarities[top_indices]

    def get_response(self, user_query):
        """Generate a reply to *user_query* and record its metrics.

        Returns:
            A tuple ``(response, metrics)``. On failure the response is an
            error string and ``metrics`` is ``{'error': <message>}``; the
            metrics history is not updated in that case.
        """
        try:
            start_time = datetime.now()

            # Truncate over-long prompts to the tokenizer's model limit and
            # move the tensors to the same device as the model.
            encoded = self.tokenizer(user_query, return_tensors="pt", truncation=True)
            encoded = encoded.to(self.bartpho_model.device)

            with torch.no_grad():
                outputs = self.bartpho_model.generate(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded.get("attention_mask"),
                    max_length=100,
                    num_beams=5,
                    early_stopping=True
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            end_time = datetime.now()
            metrics = {
                # Similarity, topic and condition detection are placeholders.
                'similarity': 0.0,
                'response_time': (end_time - start_time).total_seconds(),
                # Whitespace-separated word count of the decoded response.
                'tokens': len(response.split()),
                'topic': "N/A",
                'detected_condition': "N/A"
            }

            # Update metrics history
            history = self.metrics_history
            history['similarities'].append(metrics['similarity'])
            history['response_times'].append(metrics['response_time'])
            history['token_counts'].append(metrics['tokens'])
            topic_id = "N/A"
            history['topics_accessed'][topic_id] = \
                history['topics_accessed'].get(topic_id, 0) + 1

            return response, metrics

        except Exception as e:
            return f"Error processing query: {str(e)}", {'error': str(e)}

    def get_dataset_info(self):
        """Return information about the loaded dataset and aggregate metrics.

        Returns:
            A dict with ``dataset_info`` and ``metrics`` keys. If the data
            cannot be assembled, a dict with an ``error`` message and both of
            those keys set to ``None`` is returned instead.
        """
        try:
            history = self.metrics_history
            similarities = list(history['similarities'])
            response_times = list(history['response_times'])
            return {
                'dataset_info': self.dataset_info,
                'metrics': {
                    # Cast to plain floats so callers never see NumPy scalars.
                    'avg_similarity': float(np.mean(similarities)) if similarities else 0,
                    'avg_response_time': float(np.mean(response_times)) if response_times else 0,
                    'total_tokens': sum(history['token_counts']),
                    'topics_accessed': history['topics_accessed']
                }
            }
        except Exception as e:
            return {
                'error': str(e),
                'dataset_info': None,
                'metrics': None
            }
194
+
195
@st.cache_resource
def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
    """Build a BERTopicChatbot for the given dataset settings.

    The function is wrapped in ``st.cache_resource``; the arguments are
    forwarded unchanged to the ``BERTopicChatbot`` constructor.
    """
    chatbot = BERTopicChatbot(
        dataset_name,
        text_column,
        split,
        max_samples,
    )
    return chatbot
198
+
199
def _render_sidebar():
    """Draw the dataset configuration sidebar and load the chatbot on request."""
    with st.sidebar:
        st.header("Dataset Configuration")
        dataset_name = st.text_input(
            "Hugging Face Dataset Name",
            value="Kanakmi/mental-disorders",
            help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
        )
        text_column = st.text_input(
            "Text Column Name",
            value="text",
            help="Enter the name of the column containing the text data"
        )
        split = st.selectbox(
            "Dataset Split",
            options=["train", "test", "val", "validation"],
            index=0
        )
        max_samples = st.number_input(
            "Maximum Samples",
            min_value=100,
            max_value=100000,
            value=10000,
            step=1000,
            help="Maximum number of samples to load from the dataset"
        )

        # Nothing else to do until the user asks for a (re)load.
        if not st.button("Load Dataset"):
            return

        with st.spinner("Loading dataset and initializing model..."):
            try:
                st.session_state.chatbot = initialize_chatbot(
                    dataset_name, text_column, split, max_samples
                )
                st.success("Dataset loaded successfully!")
            except Exception as e:
                st.error(f"Error loading dataset: {str(e)}")


def _render_chat_tab():
    """Replay the stored conversation and handle one new user prompt."""
    for past_message in st.session_state.messages:
        with st.chat_message(past_message["role"]):
            st.markdown(past_message["content"])

    bot = st.session_state.chatbot
    # The chat input is only offered once a chatbot has been initialized.
    if bot is None:
        st.info("Please load a dataset first to start chatting.")
        return

    prompt = st.chat_input("Hãy nói gì đó...")
    if not prompt:
        return

    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    response, metrics = bot.get_response(prompt)

    # Show the assistant's reply together with its per-response metrics.
    with st.chat_message("assistant"):
        st.markdown(response)
        with st.expander("Response Metrics"):
            st.json(metrics)

    st.session_state.messages.append({"role": "assistant", "content": response})


def _render_metrics_tab():
    """Show the metric charts and summary statistics for the chatbot."""
    bot = st.session_state.chatbot
    if bot is None:
        st.info("Please load a dataset first to view metrics.")
        return

    try:
        fig_similarity, fig_response_time, fig_tokens, fig_topics = bot.get_metrics_visualizations()

        left, right = st.columns(2)
        with left:
            st.plotly_chart(fig_similarity, use_container_width=True)
            st.plotly_chart(fig_tokens, use_container_width=True)

        with right:
            st.plotly_chart(fig_response_time, use_container_width=True)
            st.plotly_chart(fig_topics, use_container_width=True)

        st.subheader("Overall Statistics")
        history = bot.metrics_history
        if not history['similarities']:
            st.info("No chat history available yet. Start a conversation to see metrics.")
        else:
            avg_similarity = np.mean(list(history['similarities']))
            avg_response_time = np.mean(list(history['response_times']))
            total_tokens = sum(history['token_counts'])

            first, second, third = st.columns(3)
            with first:
                st.metric("Avg Similarity", f"{avg_similarity:.3f}")
            with second:
                st.metric("Avg Response Time", f"{avg_response_time:.3f}s")
            with third:
                st.metric("Total Tokens Used", total_tokens)
    except Exception as e:
        st.error(f"Error displaying metrics: {str(e)}")


def main():
    """Render the Streamlit page: sidebar configuration, chat tab, metrics tab."""
    st.title("🤖 Trợ Lý AI - BERTopic")
    st.caption("Trò chuyện với chúng mình nhé!")

    _render_sidebar()

    # Initialize session state variables if they don't exist
    for key, default in (('chatbot', None), ('messages', [])):
        if key not in st.session_state:
            st.session_state[key] = default

    chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])

    with chat_tab:
        _render_chat_tab()

    with metrics_tab:
        _render_metrics_tab()


if __name__ == "__main__":
    main()