|
|
| import os |
| import keras_nlp |
| import gradio as gr |
| import re |
| import gc |
| import tensorflow as tf |
| from collections import deque |
|
|
|
|
# Kaggle credentials used by keras_nlp to download the Gemma preset.
# NOTE(review): the username is hardcoded — consider reading it from the
# environment as well.
os.environ["KAGGLE_USERNAME"] = 'minkyuuukim'
# BUGFIX: os.getenv returns None when KAGGLE_KEY is unset, and assigning
# None into os.environ raises an opaque TypeError; fail with a clear message.
_kaggle_key = os.getenv("KAGGLE_KEY")
if _kaggle_key is None:
    raise RuntimeError("KAGGLE_KEY environment variable is not set")
os.environ["KAGGLE_KEY"] = _kaggle_key
# NOTE(review): KERAS_BACKEND should be set BEFORE keras/keras_nlp is
# imported (it is imported at the top of this file) for the backend choice
# to reliably take effect — confirm and consider moving this above the imports.
os.environ["KERAS_BACKEND"] = "jax"
|
|
class ChatState:
    """Manage one Gemma chat session: the model, the rolling turn history,
    and prompt construction in the Gemma chat-template format."""

    # Gemma instruction-tuned chat-template control tokens.
    __START_TURN_USER__ = "<start_of_turn>user\n"
    __START_TURN_MODEL__ = "<start_of_turn>model\n"
    __END_TURN__ = "<end_of_turn>\n"

    def __init__(self, detection_num: int):
        """Load the base Gemma model; a truthy detection_num additionally
        attaches the fine-tuned chatbot LoRA weights.

        Args:
            detection_num: depression-detection result from DetectState;
                non-zero selects the LoRA adapter.
        """
        self.model = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
        self.history = deque()
        if detection_num:
            self.model.backbone.load_lora_weights('./chatbot.lora.h5')

    def add_to_history_as_user(self, message):
        """Record a user turn, closed with the end-of-turn token."""
        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

    def add_to_history_as_model(self, message):
        """Record a model turn.

        BUGFIX: the original stored model turns WITHOUT the end-of-turn
        token, so later prompts violated the Gemma chat template (an
        unclosed model turn immediately followed by a new
        <start_of_turn>model marker).
        """
        self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

    def get_history(self):
        """Return the concatenated history, trimmed to the 6 newest turns.

        BUGFIX: the original dropped at most ONE old entry per call, so the
        history could still grow past the bound; trim in a loop instead.
        """
        while len(self.history) > 6:
            self.history.popleft()
        return "".join(self.history)

    def get_full_prompt(self):
        """Build the generation prompt: system instruction, trimmed history,
        then an open model turn for the upcoming reply."""
        prompt = self.get_history() + self.__START_TURN_MODEL__
        prompt = "You are a chatbot. Engage in friendly conversations, answer questions, and ask relevant follow-up questions to keep the conversation going." + "\n" + prompt
        return prompt

    def send_message(self, message, detection_num):
        """Append *message* as a user turn, generate a reply, record it in
        the history, and return the reply text.

        Args:
            message: the user's chat message.
            detection_num: truthy when the LoRA adapter is active; the
                adapter was fine-tuned on an Input/Response format, so the
                prompt is wrapped accordingly.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()

        if detection_num:
            response = self.model.generate(f"Input:\n{prompt}\n\nResponse:\n", max_length=4096)
            # BUGFIX: the original indexed [1] unconditionally and would
            # raise IndexError if the marker were absent from the output.
            parts = response.split("Response:\n")
            result = parts[1] if len(parts) > 1 else response
        else:
            response = self.model.generate(prompt, max_length=4096)
            result = response.split("model\n")[-1]
        self.add_to_history_as_model(result)
        return result
|
|
class DetectState():
    """Gemma model with the depression-detection LoRA adapter attached.

    Classifies a piece of text by asking the model a yes/no question and
    parsing the numeric answer out of the generated text.
    """

    def __init__(self):
        # Base Gemma preset plus the detection-specific LoRA weights.
        self.model =keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
        self.model.backbone.load_lora_weights('./detection.lora.h5')
        # Last classification result; 0 until a detection has run.
        self.detection_num=0

    def detect_message(self, message):
        """Run detection on *message* and return the parsed result as an int.

        Falls back to 0 when no digit can be parsed from the model output.
        """
        query = f"\n\nText:\n{message}\n\nQuestion:\nDoes the writer have depression?\n\nAnswer:\n"
        detection = self.model.generate(query, max_length=128)
        answer = re.search(r'Answer:\s*(-?\d)', detection)
        self.detection_num = int(answer.group(1)) if answer else 0
        return int(self.detection_num)
|
|
|
|
# All diary entries saved this session; respond() seeds the chat with the latest.
diary_entries = []
# Last depression-detection result (0 = none); selects the chatbot LoRA adapter.
dect_result=0
|
|
|
|
def save_diary_entry(entry):
    """Save a diary entry, run depression detection on it, then build the
    chat model (with or without the LoRA adapter, per the detection result).

    Returns a status string shown in the diary panel.
    """
    global dect_result
    global chat_state
    diary_entries.append(entry)
    detector = DetectState()
    dect_result = detector.detect_message(entry)
    # Free the detection model before loading the chat model — both are 2B
    # parameter presets and may not fit in memory simultaneously.
    del detector
    gc.collect()
    chat_state = ChatState(dect_result)
    return "Your entry has been saved! Click the submit button on the right to start chatting."
|
|
def respond(
    message,
    history: list,
    system_message,
):
    """Gradio ChatInterface callback.

    The very first turn is seeded with the latest diary entry instead of
    the user's chat message; afterwards the user's message is forwarded.
    Yields a single response string.
    """
    global dect_result
    global chat_state
    first_turn = len(chat_state.history) == 0
    prompt_text = diary_entries[-1] if first_turn else message
    yield chat_state.send_message(prompt_text, dect_result)
|
|
|
|
|
|
# Two-pane UI: diary editor on the left, chatbot on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Diary and Chatbot Application")


    # NOTE(review): chat_history is never read or written anywhere in this
    # file — presumably dead state; confirm before removing.
    chat_history = gr.State([])
    with gr.Row():
        # Left column: diary form; saving triggers depression detection and
        # builds the chat model (see save_diary_entry).
        with gr.Column(scale=3):
            diary_input = gr.Textbox(label="Write today's diary", lines=10)
            diary_output = gr.Textbox(label="Status message")
            diary_button = gr.Button("Save diary")
            diary_button.click(save_diary_entry, inputs=diary_input, outputs=diary_output)


        # Right column: chat UI backed by respond(); the first turn is
        # seeded with the latest diary entry.
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface( fn=respond)


if __name__ == "__main__":
    demo.launch(debug=True)