Spaces:
Sleeping
Sleeping
import html
import json
import os
import time
from urllib.parse import urlencode, urlparse, parse_qs

import requests
import streamlit as st
st.set_page_config(page_title="ViEduChat - Trợ lý AI giáo dục Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="collapsed")

# ==== MODULE URL ====
# Base URLs of the backend services come from Streamlit secrets; each endpoint
# below is one route on those services.
routing_response_module = st.secrets["ViEduQA_Routing_Module"]
retrieval_module = st.secrets["ViEduQA_Retrieval_Module"]
reranker_module = st.secrets["ViEduQA_Rerank_Module"]
abs_QA_module = st.secrets["ViEduQA_QA_Module"]

url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"
url_api_retrieval_model = f"{retrieval_module}/search"
url_api_reranker_model = f"{reranker_module}/rerank"
url_api_generation_model = f"{abs_QA_module}/answer"
url_api_extract_reference_model = f"{routing_response_module}/extract_references_unstream"

# ========== STREAMLIT UI ==========
# Inject the app stylesheet. The file is decoded explicitly as UTF-8 so the
# result does not depend on the host's locale default encoding.
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Seed the conversation with a greeting only when the session has no history yet.
if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI giáo dục Việt Nam được phát triển bởi Đào Thị Ngọc Ánh. Rất vui khi được hỗ trợ bạn trong học tập!"}]

# Header: logo image followed by the centred app title.
st.markdown("""
<div class=logo_area>
<img src="./app/static/ai.jpg"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViEduChat</h2>", unsafe_allow_html=True)
def classify_question(question, timeout=None):
    """Ask the routing module which category *question* belongs to.

    Args:
        question: The raw user question, sent as ``{"question": ...}``.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` means wait indefinitely.

    Returns:
        The ``requests.Response`` on HTTP 200 (the commented-out caller calls
        ``.json()`` on it), otherwise an error string starting with "Lỗi: "
        for an HTTP error status or a network failure.
    """
    data = {"question": question}
    try:
        response = requests.post(url_api_question_classify_model, json=data, timeout=timeout)
    except requests.RequestException as exc:
        # Report connection/timeout failures through the same error-string
        # channel as HTTP errors instead of crashing the Streamlit script.
        return f"Lỗi: {exc}"
    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
def introduce_system(question, timeout=None):
    """Request the routing module's streamed self-introduction reply.

    Args:
        question: The user question, sent as ``{"question": ...}``.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` means wait indefinitely.

    Returns:
        The open streaming ``requests.Response`` on HTTP 200 (the caller
        iterates and closes it), otherwise an error string starting with
        "Lỗi: " for an HTTP error status or a network failure.
    """
    data = {"question": question}
    try:
        response = requests.post(url_api_introduce_system_model, json=data, stream=True, timeout=timeout)
    except requests.RequestException as exc:
        # Network failures use the same error-string channel as HTTP errors.
        return f"Lỗi: {exc}"
    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
def response_unrelated_question(question, timeout=None):
    """Request the routing module's streamed reply for an off-topic question.

    Args:
        question: The user question, sent as ``{"question": ...}``.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` means wait indefinitely.

    Returns:
        The open streaming ``requests.Response`` on HTTP 200 (the caller
        iterates and closes it), otherwise an error string starting with
        "Lỗi: " for an HTTP error status or a network failure.
    """
    data = {"question": question}
    try:
        response = requests.post(url_api_unrelated_question_response_model, json=data, stream=True, timeout=timeout)
    except requests.RequestException as exc:
        # Network failures use the same error-string channel as HTTP errors.
        return f"Lỗi: {exc}"
    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
def retrieve_context(question, top_k=10, timeout=None):
    """Search the retrieval module for passages relevant to *question*.

    Args:
        question: The user question, sent as the ``query`` field.
        top_k: Number of hits requested from the service.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` means wait indefinitely.

    Returns:
        The ``results`` list from the JSON response on success (the chat
        handler reads ``item['text']`` from each entry), otherwise an error
        string starting with "Lỗi tại Retrieval Module: ".
    """
    data = {"query": question, "top_k": top_k}
    try:
        response = requests.post(url_api_retrieval_model, json=data, timeout=timeout)
    except requests.RequestException as exc:
        return f"Lỗi tại Retrieval Module: {exc}"
    if response.status_code != 200:
        return f"Lỗi tại Retrieval Module: {response.status_code} - {response.text}"
    try:
        return response.json()["results"]
    except (ValueError, KeyError, TypeError) as exc:
        # Non-JSON body, JSON that is not an object, or a missing "results"
        # key: report it like any other module failure instead of crashing.
        return f"Lỗi tại Retrieval Module: phản hồi không hợp lệ ({exc!r})"
def rerank_context(url_rerank_module, question, relevant_docs, top_k=5, timeout=None):
    """Ask the rerank module to reorder *relevant_docs* for *question*.

    Args:
        url_rerank_module: Full URL of the rerank endpoint.
        question: The user question.
        relevant_docs: Candidate passages, forwarded as-is in the JSON body.
        top_k: Number of reranked entries requested.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` means wait indefinitely.

    Returns:
        The ``reranked_docs`` value from the JSON response on success,
        otherwise an error string starting with "Lỗi tại Rerank module: ".
    """
    data = {
        "question": question,
        "relevant_docs": relevant_docs,
        "top_k": top_k,
    }
    try:
        response = requests.post(url_rerank_module, json=data, timeout=timeout)
    except requests.RequestException as exc:
        return f"Lỗi tại Rerank module: {exc}"
    if response.status_code != 200:
        return f"Lỗi tại Rerank module: {response.status_code} - {response.text}"
    try:
        return response.json()["reranked_docs"]
    except (ValueError, KeyError, TypeError) as exc:
        # Malformed payload: surface it through the error-string channel.
        return f"Lỗi tại Rerank module: phản hồi không hợp lệ ({exc!r})"
def get_abstractive_answer(context, question, timeout=None):
    """Request a streamed abstractive answer to *question* from the QA module.

    Args:
        context: Supporting context, forwarded as-is in the JSON body.
        question: The user question.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` means wait indefinitely.

    Returns:
        The open streaming ``requests.Response`` on HTTP 200 (the caller
        iterates and closes it), otherwise an error string starting with
        "Lỗi: " for an HTTP error status or a network failure.
    """
    data = {"context": context, "question": question}
    try:
        response = requests.post(url_api_generation_model, json=data, stream=True, timeout=timeout)
    except requests.RequestException as exc:
        # Callers already check isinstance(result, str) to display errors,
        # so connection/timeout failures are reported the same way.
        return f"Lỗi: {exc}"
    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
| # def get_references(context, question, answer): | |
| # data = { | |
| # "context": context, | |
| # "question": question, | |
| # "answer": answer | |
| # } | |
| # response = requests.post(url_api_extract_reference_model, json=data) | |
| # if response.status_code == 200: | |
| # return response.json()["refs"] | |
| # else: | |
| # return f"Lỗi tại module Reference Extractor: {response.status_code} - {response.text}" | |
def generate_text_effect(answer, delay=0.03):
    """Yield progressively longer word prefixes of *answer* (typing effect).

    The text is split on any whitespace and re-joined with single spaces, so
    runs of spaces and newlines are collapsed in the yielded strings.

    Args:
        answer: Full text to reveal.
        delay: Seconds to sleep before each yielded prefix.

    Yields:
        The first 1, 2, ..., N words of *answer*; nothing if it has no words.
    """
    shown = ""
    for word in answer.split():
        time.sleep(delay)
        # Extend the previous prefix instead of re-joining every word so far.
        shown = f"{shown} {word}" if shown else word
        yield shown
# Re-draw every stored chat message.
for message in st.session_state.messages:
    # Stored text is user input or model output interpolated into raw HTML
    # (unsafe_allow_html=True), so it is escaped to prevent markup injection.
    content = html.escape(message['content'])
    if message['role'] == 'assistant':
        st.markdown(f"""
<div class="assistant-message">
<img src="./app/static/ai.jpg" class="assistant-avatar" />
<div class="stMarkdown">{content}</div>
</div>
""", unsafe_allow_html=True)
    else:
        # User bubbles have no avatar image, matching how a new prompt is
        # echoed, rather than an empty <img src="">.
        st.markdown(f"""
<div class="user-message">
<div class="stMarkdown">{content}</div>
</div>
""", unsafe_allow_html=True)
def _render_assistant_message(placeholder, text, streaming=False):
    """Draw an assistant chat bubble containing *text* into *placeholder*.

    *text* is HTML-escaped because it is interpolated into markup rendered
    with ``unsafe_allow_html=True`` and may contain model output. While
    *streaming* is true a trailing "●" marks the answer as still incoming.
    """
    body = html.escape(text)
    if streaming:
        markup = f"""
<div class="assistant-message">
<img src="./app/static/ai.jpg" class="assistant-avatar" />
<div class="stMarkdown">{body}●</div>
</div>
"""
    else:
        markup = f"""
<div class="assistant-message">
<img src="./app/static/ai.jpg" class="assistant-avatar" />
<div class="stMarkdown">
{body}
</div>
</div>
"""
    placeholder.markdown(markup, unsafe_allow_html=True)


def _stream_answer_tokens(response, placeholder):
    """Consume an SSE-style answer stream and return the concatenated text.

    Every non-empty ``data: <json>`` line contributes its ``token`` field;
    ``data: [DONE]`` ends the stream and lines that are not valid JSON are
    skipped. The partial answer is re-rendered after each token. The
    response is always closed so the streamed connection is released even
    when reading stops early at ``[DONE]`` or an exception escapes.
    """
    text = ""
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8')
            if not line.startswith('data: '):
                continue
            payload = line[6:]
            if payload == '[DONE]':
                break
            try:
                event = json.loads(payload)
            except json.JSONDecodeError:
                continue
            text += event.get('token', '')
            _render_assistant_message(placeholder, text, streaming=True)
    finally:
        response.close()
    return text


def _answer_education_question(question, placeholder):
    """Run retrieval -> rerank -> generation for *question*.

    Returns the final answer text, or the error string reported by whichever
    stage failed, so the caller can display and store it either way.
    """
    retrieved = retrieve_context(question=question, top_k=10)
    if isinstance(retrieved, str):
        # Error string from the retrieval module: it must not be iterated as
        # if it were a list of hits.
        return retrieved
    passages = [item['text'] for item in retrieved]
    reranked = rerank_context(url_rerank_module=url_api_reranker_model,
                              question=question,
                              relevant_docs=passages,
                              top_k=5)
    if isinstance(reranked, str):
        # Error string from the rerank module: indexing it with [0] would pass
        # its first character to the QA module as the context.
        return reranked
    if not reranked:
        return "Lỗi tại Rerank module: không tìm thấy tài liệu liên quan."
    # Only the first reranked entry (presumably the best-scoring one) is used
    # as generation context.
    answer = get_abstractive_answer(context=reranked[0], question=question)
    if isinstance(answer, str):
        return answer
    return _stream_answer_tokens(answer, placeholder)


if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    # Echo the prompt; it is user-controlled text inside raw HTML, so escape it.
    st.markdown(f"""
<div class="user-message">
<div class="stMarkdown">{html.escape(prompt)}</div>
</div>
""", unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    message_placeholder = st.empty()

    # NOTE: query routing (classify_question -> introduce_system /
    # response_unrelated_question) and reference extraction are disabled;
    # every prompt takes the retrieval-augmented QA path. The streaming
    # helpers above can be reused for those replies when routing returns.
    full_response = _answer_education_question(prompt, message_placeholder)

    _render_assistant_message(message_placeholder, full_response)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})