Spaces:
Paused
Paused
| import gradio as gr | |
| from langchain_mistralai.chat_models import ChatMistralAI | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chains import RetrievalQA | |
| from mistralai.client import MistralClient | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder,) | |
| from langchain.schema import SystemMessage | |
| # from mistralai.models.chat_completion import ChatMessage | |
| import pprint | |
| import requests | |
| import transformers | |
| import torch | |
| import tqdm | |
| import accelerate | |
| import glob | |
| import ast # Used to convert string representation of list to an actual list | |
| from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
# Load the multilingual translation model once at import time (slow: downloads
# ~2.4 GB of weights on first run). Both directions (en->indic, indic->en) share
# this single MBart-50 many-to-many model and tokenizer.
translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translation_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
import pandas as pd
# (display name, MBart-50 language code) pairs offered in the UI dropdown.
languages_list = [("Gujarati", "gu_IN"), ('Hindi',"hi_IN") , ("Bengali","bn_IN"), ("Malayalam","ml_IN"),
                  ("Marathi","mr_IN"), ("Tamil","ta_IN"), ("Telugu","te_IN")]
# Currently selected MBart language code; mutated by intitalize_lang() when the
# user changes the dropdown. Empty until a language is picked.
lang_global = ''
def intitalize_lang(language):
    """Store the user's chosen MBart language code in the module-level state.

    Called by the Gradio dropdown's change event. (Name keeps the original
    spelling because the UI wiring references it.)
    """
    global lang_global
    lang_global = language
    print(f"intitalize_lang{lang_global}")
def english_to_indian(sentence):
    """Translate English text into the currently selected Indian language.

    The input is split into 500-character chunks (MBart has a bounded input
    length) and the translated chunks are concatenated in order.

    :param sentence: English text to translate.
    :return: Translated text in the language stored in ``lang_global``.
    """
    translated_sentence = ''
    # BUG FIX: MBart-50 language codes are case-sensitive; English is "en_XX".
    # The original "en_xx" is not a valid code (compare the correct "en_XX"
    # used in indian_to_english) and would fail the tokenizer's code lookup.
    translation_tokenizer.src_lang = "en_XX"
    chunks = [sentence[i:i + 500] for i in range(0, len(sentence), 500)]
    for chunk in chunks:
        encoded_hi = translation_tokenizer(chunk, return_tensors="pt")
        # forced_bos_token_id steers generation into the target language.
        generated_tokens = translation_model.generate(
            **encoded_hi,
            forced_bos_token_id=translation_tokenizer.lang_code_to_id[lang_global],
        )
        x = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        translated_sentence = translated_sentence + x[0]
    print(translated_sentence)
    return translated_sentence
def indian_to_english(sentence):
    """Translate text from the currently selected Indian language into English.

    Mirrors english_to_indian: the input is processed in 500-character chunks
    and the English outputs are joined in order.
    """
    translation_tokenizer.src_lang = lang_global
    parts = []
    for start in range(0, len(sentence), 500):
        piece = sentence[start:start + 500]
        encoded_hi = translation_tokenizer(piece, return_tensors="pt")
        generated_tokens = translation_model.generate(
            **encoded_hi,
            forced_bos_token_id=translation_tokenizer.lang_code_to_id["en_XX"],
        )
        decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        parts.append(decoded[0])
    translated_sentence = "".join(parts)
    print(translated_sentence)
    return translated_sentence
#Mistral inference endpoints loaded from Azure
# SECURITY(review): the API key below is hardcoded and committed in plain text.
# It should be rotated and loaded from an environment variable / secret store.
llm = ChatMistralAI(
    endpoint="https://Mistral-large-vimte-serverless.eastus2.inference.ai.azure.com" + "/v1",
    mistral_api_key="BalqnPBWYmAmh5WHBxY8ihqsgsZTr5ev",
    temperature = 0.1,   # low temperature: mostly deterministic answers
    max_tokens = 1000,
    top_k = 3
)
#We create embeddings before the run. During run, the code loads up it from local storage.
embedding = HuggingFaceEmbeddings()
chat_history=[]
# Chroma vector store persisted on disk; the collection "wep" was built offline.
store = Chroma(
    embedding_function=embedding,
    # ids = [f"{item.metadata['source']}-{index}" for index, item in enumerate(data)],
    collection_name="wep",
    persist_directory="/home/user/app/chroma_semantic/",
)
# RAG prompt: constrains the bot to WEP-related questions and injects the
# retrieved context plus the user's question.
template = """You are a chatbot designed to help the Indian users of wep.gov.in website which is a platform which aims to help resolve the information asymmetry that exists in the ecosystem for women entrepreneurs.
Do not answer anything else except questions on WEP.
Use the following pieces of context to answer the question at the end.
If you do not know the answer, say you do not know.
{context}
Question: {question}
"""
prompt = PromptTemplate(
    template=template,
    input_variables=["context", "question"]
)
print("Created Prompt Template")
# NOTE(review): this memory object is created but never attached to the chain
# below — conversation context is not actually carried between turns.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
print("Created Memory")
# "stuff" chain: all retrieved documents are stuffed into a single prompt.
qa_with_source = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=store.as_retriever(),
    chain_type_kwargs={"prompt": prompt, },
    return_source_documents=True,
)
print("Instantiated Qa_with source. Running query now.")
def format_chat_history(message, chat_history):
    """Flatten (user, bot) turn pairs into alternating 'User:'/'Assistant:' lines.

    ``message`` is accepted for call-site compatibility but is not used here.
    """
    return [
        line
        for user_turn, bot_turn in chat_history
        for line in (f"User: {user_turn}", f"Assistant: {bot_turn}")
    ]
# Dead code: an earlier version of conversation(), kept disabled inside a
# module-level string. Superseded by the active definition further down.
'''
def conversation(qa_chain, message, history):
formatted_chat_history = format_chat_history(message, history)
#response = qa_with_source.invoke({"question": message, "chat_history": formatted_chat_history})
response = qa_with_source.invoke(message)
print(response)
response_answer = response["result"]
if response_answer.find("Helpful Answer:") != -1:
response_answer = response_answer.split("Helpful Answer:")[-1]
new_history = history + [(message, response_answer)]
return qa_chain, gr.update(value=""), new_history
'''
import re
# Lookup table mapping scraped-page numbers to their source URLs; row index N
# corresponds to /content/htmls/N.html. No header row in the CSV.
df = pd.read_csv('/home/user/app/xxx.csv', header=None)
def extract_url_number(text):
    """Return every numeric id found in '/content/htmls/<id>.html' paths in *text*.

    :param text: Arbitrary text (e.g. stringified source documents).
    :return: List of id strings, in order of appearance.
    """
    # BUG FIX: the '.' before 'html' was unescaped and would match ANY
    # character (e.g. '/content/htmls/42Xhtml'); escape it so only a literal
    # '.html' suffix counts.
    pattern = r'/content/htmls/(\d+)\.html'
    matches = re.findall(pattern, text)
    return matches
def conversation(qa_chain, message, history, metadata_output):
    """Run one RAG turn: query the chain, append to history, collect source URLs.

    :param qa_chain: Gradio state object, passed straight through unchanged.
    :param message: The user's question.
    :param history: List of (user, bot) tuples shown in the chatbot widget.
    :param metadata_output: Accepted for Gradio wiring compatibility; unused.
    :return: (qa_chain, cleared-textbox update, new history, newline-joined sources)
    """
    # BUG FIX: the original computed format_chat_history(message, history) and
    # never used the result — dead work removed. (The chain is invoked with the
    # bare message, so prior turns are not sent to the model.)
    response = qa_with_source.invoke(message)
    print(response)
    response_answer = response["result"]
    # Some model outputs embed the answer after a "Helpful Answer:" marker.
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    new_history = history + [(message, response_answer)]
    # Map retrieved document file names back to source URLs via the CSV table:
    # row N of df corresponds to /content/htmls/N.html.
    documents = str(response['source_documents'])
    url_numbers = extract_url_number(documents)
    sources = []
    for match in url_numbers:
        sources.append(df.iloc[int(match)][0])
    sources = "\n".join([f" {s}" for s in sources])
    return qa_chain, gr.update(value=""), new_history, sources
import os
import pickle
import time  # BUG FIX: time.time() was used below but never imported anywhere


def save_chat(chatbot, directory="/home/user/app/chat_histories"):
    """Pickle the chat history to a timestamped file and return the filename.

    :param chatbot: Chat history (list of (user, bot) tuples) to persist.
    :param directory: Destination directory (new optional parameter; default
        preserves the original hard-coded path).
    :return: The generated filename (not the full path), e.g.
        ``chat_history_1700000000.pkl``.
    """
    # Unique, sortable filename based on the current UNIX timestamp.
    filename = f"chat_history_{int(time.time())}.pkl"
    path = os.path.join(directory, filename)
    with open(path, "wb") as f:
        pickle.dump(chatbot, f)
    # BUG FIX: the f-string had no placeholder, so the message never reported
    # where the history went; report the actual path.
    print(f"Chat history saved to {path}")
    return filename
def demo():
    """Build and launch the Gradio UI for the multilingual RAG chatbot."""
    with gr.Blocks(theme = gr.themes.Soft(), title = 'Multilingual RAG Demo') as demo:
        # Gradio state holders; vector_db and collection_name are created but
        # never written to — presumably leftovers from an earlier design.
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.HTML(
            """
            <div class="row">
            <div class="col-md-4"><img src="https://static.wixstatic.com/media/8995e3_0854373386904d08bd64a70fa5134ea1~mv2.png/v1/fill/w_160,h_52,al_c,q_85,usm_0.66_1.00_0.01,enc_auto/Dono%20Consulting%20Logo.png"/></div>
            <div class="col-md-4" style="text-align: center"><h2>Multilingual RAG Demo</h2></div>
            </div>
            """
        )
        with gr.Row():
            # Dropdown of (display name, MBart code) pairs; defaults to Hindi.
            lang_btn = gr.Dropdown(languages_list, label="Languages", value = languages_list[1],type="value", info="Choose your language",interactive = True)
        # Selecting a language updates the module-level lang_global.
        lang_btn.change(intitalize_lang, inputs = lang_btn)
        chatbot = gr.Chatbot(height=300, bubble_full_width = False, layout = 'panel')
        # NOTE(review): passing preprocess/postprocess callables to
        # Chatbot.change does not match Gradio's documented event signature
        # (fn/inputs/outputs) — confirm this actually wires translation in the
        # Gradio version deployed here.
        chatbot.change(preprocess = english_to_indian, postprocess = indian_to_english)
        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True, label="Mesage Box")
        with gr.Row():
            #source_docs_outputs = gr.Textbox(placeholder='References')
            metadata_output = gr.Textbox( label="Source")
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])
        with gr.Row():
            save_chat_btn = gr.Button("Save Chat")
        # Chatbot events
        # Enter key and Submit button run the same RAG turn.
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot, metadata_output], outputs=[qa_chain, msg, chatbot, metadata_output], queue=False)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, metadata_output], outputs=[qa_chain, msg, chatbot, metadata_output], queue=False)
        #clear_btn.click(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot], queue=False)
        clear_btn.click(lambda: [], inputs=None, outputs=[chatbot], queue=False)
        save_chat_btn.click(save_chat, inputs=[chatbot], outputs=None, queue=False)
    # SECURITY(review): basic-auth credentials are hardcoded; move them to
    # environment variables / secrets before deploying.
    demo.queue().launch(debug=True,auth=('dono_demo', 'pass1234'), favicon_path="/home/user/app/donologo.ico")
if __name__ == "__main__":
    demo()