# fidelkim's picture
# edited
# 9365de7 verified
import os
import keras_nlp
import gradio as gr
import re
import gc
import tensorflow as tf
from collections import deque
# Kaggle credentials, presumably consumed when GemmaCausalLM.from_preset()
# downloads the Gemma weights.
os.environ["KAGGLE_USERNAME"] = 'minkyuuukim'
# KAGGLE_KEY must already be present in the environment (e.g. a deployment
# secret). It is deliberately not re-assigned here: os.environ only accepts
# str values, so copying os.getenv("KAGGLE_KEY") back into it raises TypeError
# whenever the variable is unset, and is a no-op whenever it is set.
# NOTE(review): Keras 3 selects its backend when `keras` is first imported, and
# keras_nlp is imported above this line, so this assignment likely has no
# effect unless it is moved before the imports — confirm jax is installed in the
# deployment before moving it.
os.environ["KERAS_BACKEND"] = "jax"
class ChatState:
    """Gemma chat session that keeps a short, pre-formatted turn history.

    Each user/model turn is stored as one string already wrapped in Gemma's
    ``<start_of_turn>`` markers. The most recent ``max_history`` strings, plus a
    fixed system instruction, form the prompt for every generation.
    """

    __START_TURN_USER__ = "<start_of_turn>user\n"
    __START_TURN_MODEL__ = "<start_of_turn>model\n"
    __END_TURN__ = "<end_of_turn>\n"
    # Number of stored turn strings (user and model turns counted separately)
    # kept when building a prompt; instances may override it.
    max_history = 6

    def __init__(self, detection_num: int, max_history: int = 6):
        """Load the Gemma 2 2B instruct preset.

        Args:
            detection_num: Label produced by ``DetectState.detect_message``;
                any truthy value also loads ``./chatbot.lora.h5`` into the
                backbone.
            max_history: How many turn strings to keep in the prompt window.
        """
        self.model = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
        self.history = deque()
        self.max_history = max_history
        if detection_num:
            self.model.backbone.load_lora_weights('./chatbot.lora.h5')

    def add_to_history_as_user(self, message):
        """Append ``message`` as a complete user turn (start and end markers)."""
        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

    def add_to_history_as_model(self, message):
        """Append ``message`` as a model turn.

        No end-of-turn marker is added here; presumably the generated text
        already carries its own ``<end_of_turn>`` — TODO confirm against the
        generate() output.
        """
        self.history.append(self.__START_TURN_MODEL__ + message)

    def get_history(self):
        """Drop the oldest turns beyond ``max_history``; return the rest joined."""
        # Loop rather than a single popleft(): send_message() adds two turns
        # per call but trims only once, so one pop per call would let the
        # history (and the prompt) grow without bound.
        while self.history and len(self.history) > self.max_history:
            self.history.popleft()
        return "".join(self.history)

    def get_full_prompt(self):
        """Return system instruction + recent history + an open model turn."""
        prompt = self.get_history() + self.__START_TURN_MODEL__
        prompt = "You are a chatbot. Engage in friendly conversations, answer questions, and ask relevant follow-up questions to keep the conversation going." + "\n" + prompt
        return prompt

    @staticmethod
    def _generated_text(response, model_input, marker):
        """Return the part of ``response`` produced after ``model_input``.

        The parsing relies on generate() echoing its input before the
        completion. Stripping the exact input keeps replies intact even when
        they contain ``marker`` themselves. If the echo is not verbatim, fall
        back to the text after the last ``marker`` occurrence.
        """
        if response.startswith(model_input):
            return response[len(model_input):]
        return response.rpartition(marker)[2]

    def send_message(self, message, detection_num):
        """Send one user message and return the model's reply.

        Args:
            message: User text to append to the history.
            detection_num: Truthy values wrap the prompt in an
                ``Input:/Response:`` template (presumably the format the
                chatbot LoRA was tuned on — TODO confirm).

        Returns:
            The generated reply, which is also recorded as a model turn.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        if detection_num:
            model_input = f"Input:\n{prompt}\n\nResponse:\n"
            marker = "Response:\n"
        else:
            model_input = prompt
            marker = "model\n"
        response = self.model.generate(model_input, max_length=4096)
        result = self._generated_text(response, model_input, marker)
        self.add_to_history_as_model(result)
        return result
class DetectState:
    """Classify diary text with a LoRA-tuned Gemma 2 2B instruct model."""

    def __init__(self):
        """Load the Gemma preset and apply ``./detection.lora.h5``."""
        self.model = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
        self.model.backbone.load_lora_weights('./detection.lora.h5')
        # Most recent label returned by detect_message(); 0 until then.
        self.detection_num = 0

    def detect_message(self, message, max_length=128):
        """Ask the model whether ``message`` indicates depression.

        Args:
            message: Diary text to classify.
            max_length: Forwarded to ``generate``. NOTE(review): in keras_nlp
                this appears to bound prompt + completion together, so long
                diary entries may leave no room for an answer — confirm.

        Returns:
            The first (optionally negative) digit the model emits after the
            prompt's ``Answer:`` line, or 0 when none is found. The value is
            also stored in ``self.detection_num``.
        """
        prompt = f"\n\nText:\n{message}\n\nQuestion:\nDoes the writer have depression?\n\nAnswer:\n"
        detection = self.model.generate(prompt, max_length=max_length)
        # Parse only what the model generated. Searching the whole echoed text
        # would let an "Answer: <digit>" typed inside the diary entry win.
        if detection.startswith(prompt):
            completion = detection[len(prompt):]
        else:
            _, found, completion = detection.rpartition("Answer:")
            if not found:
                completion = ""
        match = re.match(r'\s*(-?\d)', completion)
        self.detection_num = int(match.group(1)) if match else 0
        return self.detection_num
# Every diary entry saved this session, oldest first.
diary_entries = []
# Label returned by the most recent DetectState.detect_message() call.
dect_result = 0
# Current ChatState; None until the first diary entry has been saved.
chat_state = None


def save_diary_entry(entry):
    """Store a diary entry, classify it, and start a fresh chat session.

    Args:
        entry: Diary text from the Gradio textbox.

    Returns:
        A status message for the UI.
    """
    global dect_result
    global chat_state
    if not entry or not entry.strip():
        return "Please write something in your diary before saving."
    diary_entries.append(entry)
    # Release the previous chat session before loading new models. Both
    # DetectState and ChatState load a full Gemma model, and keeping the old
    # session referenced would hold an extra model alive during the reload.
    chat_state = None
    gc.collect()
    ds = DetectState()
    dect_result = ds.detect_message(entry)
    # The detection model is only needed once; drop it before the chat model
    # loads to keep peak memory down.
    del ds
    gc.collect()
    chat_state = ChatState(dect_result)
    return "Your entry has been saved! Click the submit button on the right to start chatting."
def respond(
    message,
    history: list,
    system_message=None,
):
    """Gradio ChatInterface callback: yield the chatbot's reply to ``message``.

    Args:
        message: Text the user submitted in the chat box.
        history: Gradio's chat history; unused because ChatState keeps its own.
        system_message: Unused. It defaults to None because gr.ChatInterface
            passes only (message, history) unless additional_inputs are set.

    Yields:
        A single reply string.
    """
    try:
        state = chat_state
    except NameError:
        # No diary has been saved yet, so the global was never created.
        state = None
    if state is None:
        yield "Please save a diary entry first."
        return
    # On the first turn, the saved diary entry is sent instead of the typed
    # message, so the conversation opens on the diary's contents.
    if not state.history:
        response = state.send_message(diary_entries[-1], dect_result)
    else:
        response = state.send_message(message, dect_result)
    yield response
# Two-column Gradio UI: the diary editor on the left, the chatbot on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Diary and Chatbot Application")
    # NOTE(review): this State is not passed to any event handler here and
    # looks unused.
    chat_history = gr.State([])
    with gr.Row():
        # Left (wider) column: diary text, status output, and the save button.
        with gr.Column(scale=3):
            diary_input = gr.Textbox(label="Write today's diary", lines=10)
            diary_output = gr.Textbox(label="Status message")
            diary_button = gr.Button("Save diary")
            # Saving runs the depression-detection model on the entry and
            # (re)creates the chat session; the returned status is displayed.
            diary_button.click(save_diary_entry, inputs=diary_input, outputs=diary_output)
        # Right column: chat UI backed by respond().
        # NOTE(review): no additional_inputs are given, so ChatInterface calls
        # respond(message, history) with two arguments — confirm respond's third
        # parameter has a default.
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface( fn=respond)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)