class ChatState:
    """
    Manages the conversation history for a turn-based chatbot.

    Each turn is stored as a pre-formatted string. User turns are prefixed
    with ``Instruction:`` and model turns with ``Response:`` (see the class
    constants below). NOTE(review): the turn-based design was modelled on the
    Gemma conversation guidelines at https://ai.google.dev/gemma/docs/formatting,
    but the markers defined here are plain-text Instruction/Response markers,
    not Gemma control tokens -- confirm this matches the model being used.
    """

    # Text placed before every user message.
    __START_TURN_USER__ = "Instruction:\n"
    # Text placed before every model reply; also appended to the prompt to
    # cue the model to answer.
    __START_TURN_MODEL__ = "\n\nResponse:\n"
    # Text placed after every user message (currently empty).
    __END_TURN__ = ""

    def __init__(self, model, system=""):
        """
        Initializes the chat state.

        Args:
            model: The language model to use for generating responses. It
                must expose ``generate(prompt, max_length=...)`` returning
                a string.
            system: (Optional) System instructions or bot description.
                ``None`` or an empty string means no system text.
        """
        self.model = model
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message: str) -> None:
        """
        Adds a user message to the history with start/end turn markers.

        Args:
            message: The user's message text.
        """
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN__
        )

    def add_to_history_as_model(self, message: str) -> None:
        """
        Adds a model response to the history with the start turn marker.

        A trailing newline is appended so the next user turn starts on a
        fresh line.

        Args:
            message: The model's response text.
        """
        self.history.append(self.__START_TURN_MODEL__ + message + "\n")

    def get_history(self) -> str:
        """
        Returns the entire chat history as a single string.
        """
        return "".join(self.history)

    def get_full_prompt(self) -> str:
        """
        Builds the prompt for the language model.

        The prompt is the optional system text (followed by a newline), the
        full history, and a trailing model start marker that cues the model
        to produce its reply.

        Returns:
            The prompt string.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        # Truthiness check so that both "" and None mean "no system text".
        if self.system:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message: str, max_length: int = 4096) -> str:
        """
        Handles sending a user message and getting a model response.

        The user message and the model's reply are both appended to the
        history.

        Args:
            message: The user's message.
            max_length: Value forwarded to ``model.generate`` as its
                ``max_length`` keyword argument.

        Returns:
            The model's response, with the echoed prompt removed from the
            front when the model returns it.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(prompt, max_length=max_length)
        # Strip only the leading echo of the prompt. A global replace would
        # also delete any later occurrence of the prompt text inside the
        # generated answer.
        if response.startswith(prompt):
            result = response[len(prompt):]
        else:
            result = response
        self.add_to_history_as_model(result)
        return result
|
|