Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from pymongo import MongoClient | |
| from langchain_mongodb import MongoDBAtlasVectorSearch | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
| from dotenv import load_dotenv | |
| from streamlit_option_menu import option_menu | |
| import os | |
| import re | |
| from prompt import PROFESSIONAL_PROMPT | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory | |
| from datetime import datetime, timezone | |
| from handler import is_feedback_message, extract_feedback_content, OPENAI_KEY, collection, store_last_qna, add_question_ticker | |
# Initialize Embeddings
# Embedding model used both to index documents and to vectorize incoming
# queries for Atlas similarity search.
# NOTE(review): dimensions=1536 must match the dimension declared in the
# Atlas "vector_index" search index — confirm against the index definition.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    openai_api_key=OPENAI_KEY,
    dimensions=1536
)
def init_memory():
    """Build a fresh sliding-window conversation memory.

    Keeps only the last 3 exchanges (k=3) under the "chat_history" key and
    reads the chain's 'answer' output, as required by
    ConversationalRetrievalChain with return_source_documents=True.
    """
    window_memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        k=3,
        return_messages=True,
        output_key='answer',
    )
    return window_memory
# Per-user conversation memories, keyed by user id. Module-level, so history
# persists across calls within this process (not across restarts).
memory_store = {}

def get_user_memory(user_id: str) -> ConversationBufferWindowMemory:
    """Return the cached conversation memory for *user_id*, creating it on
    first use.

    Fix: previously this duplicated the entire ConversationBufferWindowMemory
    construction already encapsulated by init_memory(); it now delegates to
    init_memory() so the window configuration lives in exactly one place.
    """
    if user_id not in memory_store:
        memory_store[user_id] = init_memory()
    return memory_store[user_id]
# Vector Store Configuration
# MongoDB Atlas vector search over `collection` (imported from handler).
# Document text is read from the "text" field; the Atlas search index is
# named "vector_index".
vector_store = MongoDBAtlasVectorSearch(
    collection=collection,
    embedding=embeddings,
    index_name='vector_index',
    text_key="text"
)
# Model Configuration
# Deterministic (temperature=0) GPT-4 chat model; responses capped at
# 800 completion tokens.
llm = ChatOpenAI(
    model_name="gpt-4",
    openai_api_key=OPENAI_KEY,
    temperature=0,
    max_tokens=800,
)
# @st.cache_resource
# Cache of per-user ConversationalRetrievalChain instances, so each user's
# chain (and its attached memory) is built once and reused across questions.
qa_chains = {}

def create_conversational_chain(user_id: str = "anonymous"):
    """Return the retrieval QA chain for *user_id*, creating and caching it
    on first use.

    Fixes over the previous version:
    - the parameter was annotated ``user_id: None`` with no default; it is
      now ``str`` with an "anonymous" default (backward-compatible).
    - a brand-new memory was created on every call via init_memory(), so the
      k=3 conversation window was always empty and follow-up questions lost
      their context; the persistent per-user memory from get_user_memory()
      is now used instead.
    - the chain was rebuilt (and the qa_chains cache overwritten) on every
      call; it is now built once per user.
    """
    memory = get_user_memory(user_id)
    print(f"[memory] messages for {user_id}: {len(memory.chat_memory.messages)}")
    if user_id not in qa_chains:
        retriever = vector_store.as_retriever(
            search_type="similarity",
            # NOTE(review): score_threshold is normally only honored with
            # search_type="similarity_score_threshold" — confirm it has any
            # effect with plain "similarity" search.
            search_kwargs={"k": 3, "score_threshold": 0.8}
        )
        qa_chains[user_id] = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            combine_docs_chain_kwargs={"prompt": PROFESSIONAL_PROMPT},
            return_source_documents=True
        )
    return qa_chains[user_id]
| def clean_answer(raw_answer): | |
| formatted = re.sub(r'(\d+\.)\s', r'\n\1 ', raw_answer) | |
| cleaned = re.sub(r'[*_]{2}', '', formatted) | |
| reference_checked = re.sub(r'^.*Read more at(?!.*(https?://|www\.)).*$', '', cleaned, flags=re.MULTILINE) | |
| return reference_checked.strip() | |
def ask(query, user_id="anonymous"):
    """Answer *query* through the retrieval chain and record the Q&A.

    Returns the cleaned answer text, a fixed fallback message when no source
    documents matched, or an HTML-formatted error block if anything raised.

    Fixes over the previous version:
    - the chain was always created for the literal "anonymous", ignoring the
      *user_id* parameter — every user shared one conversation memory, and
      the chain used did not match the user stored by store_last_qna();
    - "occured" typo and mojibake characters in the user-facing error HTML.
    """
    try:
        qa = create_conversational_chain(user_id)
        add_question_ticker(query)
        result = qa({"question": query})
        # NOTE(review): ConversationalRetrievalChain does not normally emit a
        # '__raw' key, so this usually prints {} — confirm before relying on it.
        usage = result.get('__raw', {}).get('usage', {})
        print(f"Tokens used: {usage}")
        if not result['source_documents']:
            return "Sorry, no relevant information found on the question asked. Please contact immigration customer service through https://www.imigrasi.go.id/hubungi."
        answer = clean_answer(result["answer"])
        store_last_qna(user_id, query, answer)
        return answer
    except Exception as e:
        # Top-level boundary: return a friendly HTML block instead of raising
        # into the Streamlit UI.
        error_msg = f"""
        <div class="assistant-message">
            ⚠️ An error occurred<br>
            An error occurred, please try again or contact us:<br>
            • Email: galuh.adika@gmail.com<br>
            error: {str(e)}<br>
        </div>
        """
        return error_msg