import warnings

import gradio as gr
import torch
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from peft import PeftConfig, PeftModel

warnings.filterwarnings("ignore")
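# Assumption (not stated in the original app): the 4-bit NF4 loading below relies on
# the bitsandbytes package and a CUDA GPU, and will fail on CPU-only hardware.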
# initialize and load PEFT model and tokenizer
def init_model_and_tokenizer(PEFT_MODEL):
    config = PeftConfig.from_pretrained(PEFT_MODEL)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    peft_base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

    peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    peft_tokenizer.pad_token = peft_tokenizer.eos_token  # Falcon has no pad token; reuse EOS

    return peft_model, peft_tokenizer
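
# Illustrative sketch, not part of the original app: a quick sanity check that the
# 4-bit PEFT model generates text before it is wired into the LangChain chain.
# The helper name and prompt string here are hypothetical.
def _smoke_test_generation(peft_model, peft_tokenizer):
    prompt = "<HUMAN>: How are you feeling today?\n<ASSISTANT>:"
    encoding = peft_tokenizer(prompt, return_tensors="pt").to(peft_model.device)
    outputs = peft_model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        max_new_tokens=32,
        pad_token_id=peft_tokenizer.eos_token_id,
    )
    return peft_tokenizer.decode(outputs[0], skip_special_tokens=True)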
# custom LLM chain to generate answer from PEFT model for each query
def init_llm_chain(peft_model, peft_tokenizer):
    class CustomLLM(LLM):
        def _call(self, prompt: str, stop=None, run_manager=None) -> str:
            device = "cuda:0"
            peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
            peft_outputs = peft_model.generate(
                input_ids=peft_encoding.input_ids,
                attention_mask=peft_encoding.attention_mask,  # pass the mask to generate(), not to GenerationConfig
                generation_config=GenerationConfig(
                    max_new_tokens=128,
                    pad_token_id=peft_tokenizer.eos_token_id,
                    eos_token_id=peft_tokenizer.eos_token_id,
                    do_sample=True,  # required for temperature/top_p to take effect
                    temperature=0.4,
                    top_p=0.6,
                    repetition_penalty=1.3,
                    num_return_sequences=1,
                ),
            )
            peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
            return peft_text_output

        @property
        def _llm_type(self) -> str:
            return "custom"

    llm = CustomLLM()
    template = """Answer the following question truthfully.
If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.

Example Format:
<HUMAN>: question here
<ASSISTANT>: answer here

Begin!

<HUMAN>: {query}
<ASSISTANT>:"""
    prompt = PromptTemplate(template=template, input_variables=["query"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    return llm_chain
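
# Illustrative usage sketch (the query string is hypothetical):
#   chain = init_llm_chain(peft_model, peft_tokenizer)
#   raw = chain.run("What are common symptoms of anxiety?")
# The decoded output echoes the whole prompt, which is why post_process_chat()
# below trims everything before the answer.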
def user(user_message, history):
    # queue the user turn with an empty bot slot; bot() fills it in
    return "", history + [[user_message, None]]
def bot(history):
    # prepend the previous exchange (when one exists) so the model sees one turn of context
    if len(history) >= 2 and history[-2][1] is not None:
        query = history[-2][0] + "\n" + history[-2][1] + "\nHere is the next QUESTION: " + history[-1][0]
    else:
        query = history[-1][0]

    bot_message = llm_chain.run(query)
    bot_message = post_process_chat(bot_message, query)

    history[-1][1] = bot_message
    return history
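
# For example (hypothetical turns): with history [["Q1", "A1"], ["Q2", None]],
# the assembled query is "Q1\nA1\nHere is the next QUESTION: Q2".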
def post_process_chat(bot_message, query):
    # find where ": {query}" echoes back in the generated text
    query_position = bot_message.find(f": {query}")

    if query_position != -1:
        # keep only the part of the response after the echoed query
        response_part = bot_message[query_position + len(f": {query}"):].strip()
        last_period_position = response_part.rfind(".")

        if last_period_position != -1:
            # drop any dangling fragment after the last full sentence
            new_response_part = response_part[:last_period_position + 1].strip()
            return new_response_part

    # return the original response if ": {query}" is not found
    return bot_message
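
# Illustrative walk-through (hypothetical strings): for
#   query = "How are you?"
#   bot_message = "...<HUMAN>: How are you?\n<ASSISTANT>: I am fine. Extra tok"
# the function keeps the text after ": How are you?" and trims the fragment
# after the last period, returning "<ASSISTANT>: I am fine."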
model = "uzairsiddiqui/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL=model)
with gr.Blocks() as demo:
    gr.HTML("""<h1 style="text-align: center">Welcome to Mental Health Conversational AI</h1>""")
    gr.Markdown(
        """A chatbot designed to provide psychoeducation, offer non-judgemental and empathetic support, and help with self-assessment and monitoring.
        Get an instant response to any mental-health-related query. If the chatbot senses that you need external support, it will respond accordingly."""
    )

    chatbot = gr.Chatbot()
    query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
    clear = gr.Button(value="Clear Chat History!")

    llm_chain = init_llm_chain(peft_model, peft_tokenizer)

    query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    demo.queue().launch()