Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import json | |
| def load_model_and_tokenizer(model_name='intfloat/multilingual-e5-small'): | |
| """ | |
| Cached function to load model and tokenizer | |
| This ensures the model is loaded only once and reused | |
| """ | |
| print("Loading tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| print("Loading model...") | |
| model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16) | |
| return tokenizer, model | |
| class VietnameseChatbot: | |
| def __init__(self, model_name='intfloat/multilingual-e5-small'): | |
| """ | |
| Initialize the Vietnamese chatbot with pre-loaded model and conversation data | |
| """ | |
| # Load pre-trained model and tokenizer using cached function | |
| self.tokenizer, self.model = load_model_and_tokenizer(model_name) | |
| # Load comprehensive conversation dataset | |
| self.conversation_data = self._load_conversation_data() | |
| # Pre-compute embeddings for faster response generation | |
| print("Pre-computing conversation embeddings...") | |
| self.conversation_embeddings = self._compute_embeddings() | |
| def _load_conversation_data(self): | |
| """ | |
| Load a comprehensive conversation dataset | |
| """ | |
| return [ | |
| # Greeting conversations | |
| {"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"}, | |
| {"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."}, | |
| {"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."}, | |
| # Identity and purpose | |
| {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."}, | |
| {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."}, | |
| # Small talk | |
| {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."}, | |
| {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."}, | |
| # Weather and time | |
| {"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."}, | |
| {"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."}, | |
| # Assistance offers | |
| {"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"}, | |
| {"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."}, | |
| # Farewell | |
| {"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."}, | |
| {"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."}, | |
| ] | |
| def _compute_embeddings(_self): # Add underscore to self parameter | |
| """ | |
| Pre-compute embeddings for conversation queries | |
| Cached to avoid recomputing on every run | |
| """ | |
| def embed_single_text(text, tokenizer, model): | |
| try: | |
| # Tokenize and generate embeddings | |
| inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True) | |
| with torch.no_grad(): | |
| model_output = model(**inputs) | |
| # Mean pooling | |
| token_embeddings = model_output[0] | |
| input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float() | |
| embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | |
| return embeddings.numpy()[0] | |
| except Exception as e: | |
| print(f"Embedding error: {e}") | |
| return None | |
| embeddings = [] | |
| for conversation in _self.conversation_data: # Use _self instead of self | |
| embedding = embed_single_text(conversation['query'], _self.tokenizer, _self.model) # Use _self instead of self | |
| if embedding is not None: | |
| embeddings.append(embedding) | |
| return np.array(embeddings) | |
| def embed_text(self, text): | |
| """ | |
| Generate embeddings for input text | |
| """ | |
| try: | |
| # Tokenize and generate embeddings | |
| inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True) | |
| with torch.no_grad(): | |
| model_output = self.model(**inputs) | |
| # Mean pooling | |
| embeddings = self.mean_pooling(model_output, inputs['attention_mask']) | |
| return embeddings.numpy() | |
| except Exception as e: | |
| print(f"Embedding error: {e}") | |
| return None | |
| def mean_pooling(self, model_output, attention_mask): | |
| """ | |
| Perform mean pooling on model output | |
| """ | |
| token_embeddings = model_output[0] | |
| input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
| return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | |
| def get_response(self, user_query): | |
| """ | |
| Find the most similar response from conversation data | |
| """ | |
| try: | |
| # Embed user query | |
| query_embedding = self.embed_text(user_query) | |
| if query_embedding is None: | |
| return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn." | |
| # Calculate cosine similarities | |
| similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0] | |
| # Find most similar response | |
| best_match_index = np.argmax(similarities) | |
| # Return response if similarity is above threshold | |
| if similarities[best_match_index] > 0.5: | |
| return self.conversation_data[best_match_index]['response'] | |
| return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?" | |
| except Exception as e: | |
| print(f"Response generation error: {e}") | |
| return "Đã xảy ra lỗi. Xin vui lòng thử lại." | |
| def initialize_chatbot(): | |
| """ | |
| Cached function to initialize the chatbot | |
| This ensures the chatbot is created only once | |
| """ | |
| return VietnameseChatbot() | |
| def main(): | |
| st.set_page_config( | |
| page_title="Trợ Lý AI Tiếng Việt", | |
| page_icon="🤖", | |
| ) | |
| st.title("🤖 Trợ Lý AI Tiếng Việt") | |
| st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ") | |
| # Initialize chatbot using cached initialization | |
| chatbot = initialize_chatbot() | |
| # Chat history in session state | |
| if 'messages' not in st.session_state: | |
| st.session_state.messages = [] | |
| # Sidebar for additional information | |
| with st.sidebar: | |
| st.header("Về Trợ Lý AI") | |
| st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.") | |
| st.write("Mô hình sử dụng: intfloat/multilingual-e5-small") | |
| # Display chat messages | |
| for message in st.session_state.messages: | |
| with st.chat_message(message["role"]): | |
| st.markdown(message["content"]) | |
| # User input | |
| if prompt := st.chat_input("Hãy nói gì đó..."): | |
| # Add user message to chat history | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| # Display user message | |
| with st.chat_message("user"): | |
| st.markdown(prompt) | |
| # Get chatbot response | |
| response = chatbot.get_response(prompt) | |
| # Display chatbot response | |
| with st.chat_message("assistant"): | |
| st.markdown(response) | |
| # Add assistant message to chat history | |
| st.session_state.messages.append({"role": "assistant", "content": response}) | |
| if __name__ == "__main__": | |
| main() |