Spaces:
Runtime error
Runtime error
| import json | |
| import time | |
| from os.path import exists, normpath | |
| from pathlib import Path | |
| from typing import List, Dict | |
| import yaml | |
| class User: | |
| """ | |
| Class stored individual tg user info (history, message sequence, etc...) and provide some actions | |
| """ | |
| def __init__( | |
| self, | |
| char_file="", | |
| user_id=0, | |
| name1="You", | |
| name2="Bot", | |
| context="", | |
| example="", | |
| language="en", | |
| silero_speaker="None", | |
| silero_model_id="None", | |
| turn_template="", | |
| greeting="Hello.", | |
| ): | |
| """Init User class with default attribute | |
| Args: | |
| name1: username | |
| name2: current character name | |
| context: context of conversation, example: "Conversation between Bot and You" | |
| greeting: just greeting message from bot | |
| """ | |
| self.char_file: str = char_file | |
| self.user_id: int = user_id | |
| self.name1: str = name1 | |
| self.name2: str = name2 | |
| self.context: str = context | |
| self.example: str = example | |
| self.language: str = language | |
| self.silero_speaker: str = silero_speaker | |
| self.silero_model_id: str = silero_model_id | |
| self.turn_template: str = turn_template | |
| self.text_in: List[str] = [] # "user input history": ["Hi!","Who are you?"], need for regenerate option | |
| self.name_in: List[str] = [] # user_name history need to correct regenerate option | |
| self.history: List[Dict[str]] = [] # "history": [["in": "query1", "out": "answer1"],["in": "query2",... | |
| self.previous_history: Dict[str : List[str]] = {} # "previous_history": | |
| self.msg_id: List[int] = [] # "msg_id": [143, 144, 145, 146], | |
| self.greeting: str = greeting # "hello" or something | |
| self.last_msg_timestamp: int = 0 # last message timestamp to avoid message flood. | |
| def __or__(self, arg): | |
| return arg | |
| def history_last_out(self) -> str: | |
| return self.history[-1]["out"] | |
| def history_last_in(self) -> str: | |
| return self.history[-1]["in"] | |
| def truncate_last_message(self): | |
| """Truncate user history (minus one answer and user input) | |
| Returns: | |
| user_in: truncated user input string | |
| msg_id: truncated answer message id (to be deleted in chat) | |
| """ | |
| msg_id = self.msg_id.pop() | |
| user_in = self.text_in.pop() | |
| self.name_in.pop() | |
| self.history.pop() | |
| return user_in, msg_id | |
| def history_append(self, message="", answer=""): | |
| self.history.append({"in": message, "out": answer}) | |
| def history_as_str(self) -> str: | |
| history = "" | |
| if len(self.history) == 0: | |
| return history | |
| for s in self.history: | |
| if len(s["in"]) > 0: | |
| history += s["in"] | |
| if len(s["out"]) > 0: | |
| history += s["out"] | |
| return history | |
| def history_as_list(self) -> list: | |
| history_list = [] | |
| if len(self.history) == 0: | |
| return history_list | |
| for s in self.history: | |
| if len(s["in"]) > 0: | |
| history_list.append(s["in"]) | |
| if len(s["out"]) > 0: | |
| history_list.append(s["out"]) | |
| return history_list | |
| def change_last_message(self, text_in=None, name_in=None, history_in=None, history_out=None, msg_id=None): | |
| if text_in: | |
| self.text_in[-1] = text_in | |
| if name_in: | |
| self.name_in[-1] = name_in | |
| if history_in: | |
| self.history[-1]["in"] = history_in | |
| if history_out: | |
| self.history[-1]["out"] = history_out | |
| if msg_id: | |
| self.msg_id[-1] = msg_id | |
| def back_to_previous_out(self, msg_id): | |
| if str(msg_id) in self.previous_history: | |
| last_out = self.history_last_out | |
| new_out = self.previous_history[str(msg_id)].pop(-1) | |
| self.history[-1]["out"] = new_out | |
| self.previous_history[str(msg_id)].insert(0, last_out) | |
| return self.history_last_out | |
| else: | |
| return None | |
| def reset(self): | |
| """Clear bot history and reset to default everything but language, silero and chat_file.""" | |
| self.name1 = "You" | |
| self.name2 = "Bot" | |
| self.context = "" | |
| self.example = "" | |
| self.turn_template = "" | |
| self.text_in = [] | |
| self.name_in = [] | |
| self.history = [] | |
| self.previous_history = {} | |
| self.msg_id = [] | |
| self.greeting = "Hello." | |
| def to_json(self): | |
| """Convert user data to json string. | |
| Returns: | |
| user data as json string | |
| """ | |
| return json.dumps( | |
| { | |
| "char_file": self.char_file, | |
| "user_id": self.user_id, | |
| "name1": self.name1, | |
| "name2": self.name2, | |
| "context": self.context, | |
| "example": self.example, | |
| "language": self.language, | |
| "silero_speaker": self.silero_speaker, | |
| "silero_model_id": self.silero_model_id, | |
| "turn_template": self.turn_template, | |
| "text_in": self.text_in, | |
| "name_in": self.name_in, | |
| "history": self.history, | |
| "previous_history": self.previous_history, | |
| "msg_id": self.msg_id, | |
| "greeting": self.greeting, | |
| } | |
| ) | |
| def from_json(self, json_data: str): | |
| """Convert json string data to internal variables of User class | |
| Args: | |
| json_data: user json data string | |
| Returns: | |
| True if success, otherwise False | |
| """ | |
| data = json.loads(json_data) | |
| try: | |
| self.char_file = data["char_file"] if "char_file" in data else "" | |
| self.user_id = data["user_id"] if "user_id" in data else 0 | |
| self.name1 = data["name1"] if "name1" in data else "You" | |
| self.name2 = data["name2"] if "name2" in data else "Bot" | |
| self.context = data["context"] if "context" in data else "" | |
| self.example = data["example"] if "example" in data else "" | |
| self.language = data["language"] if "language" in data else "en" | |
| self.silero_speaker = data["silero_speaker"] if "silero_speaker" in data else "None" | |
| self.silero_model_id = data["silero_model_id"] if "silero_model_id" in data else "None" | |
| self.turn_template = data["turn_template"] if "turn_template" in data else "" | |
| self.text_in = data["text_in"] if "text_in" in data else [] | |
| self.name_in = data["name_in"] if "name_in" in data else [] | |
| self.history = data["history"] if "history" in data else [] | |
| self.previous_history = data["previous_history"] if "previous_history" in data else {} | |
| self.msg_id = data["msg_id"] if "msg_id" in data else [] | |
| self.greeting = data["greeting"] if "greeting" in data else "Hello." | |
| return True | |
| except Exception as exception: | |
| print("from_json", exception) | |
| return False | |
| def load_character_file(self, characters_dir_path: str, char_file: str): | |
| """Load character_file file. | |
| First, reset all internal user data to default | |
| Second, read character_file file as yaml or json and converts to internal User data | |
| Args: | |
| characters_dir_path: path to character dir | |
| char_file: name of character_file file | |
| Returns: | |
| True if success, otherwise False | |
| """ | |
| self.reset() | |
| # Copy default user data. If reading will fail - return default user data | |
| try: | |
| # Try to read character_file file. | |
| char_file_path = Path(f"{characters_dir_path}/{char_file}") | |
| with open(normpath(char_file_path), "r", encoding="utf-8") as user_file: | |
| if char_file.split(".")[-1] == "json": | |
| data = json.loads(user_file.read()) | |
| else: | |
| data = yaml.safe_load(user_file.read()) | |
| # load persona and scenario | |
| self.char_file = char_file | |
| if "user" in data: | |
| self.name1 = data["user"] | |
| if "bot" in data: | |
| self.name2 = data["bot"] | |
| if "you_name" in data: | |
| self.name1 = data["you_name"] | |
| if "char_name" in data: | |
| self.name2 = data["char_name"] | |
| if "name" in data: | |
| self.name2 = data["name"] | |
| if "turn_template" in data: | |
| self.turn_template = data["turn_template"] | |
| self.context = "" | |
| if "char_persona" in data: | |
| self.context += f"{self.name2}'s persona: {data['char_persona'].strip()}\n" | |
| if "context" in data: | |
| if data["context"].strip() not in self.context: | |
| self.context += f"{data['context'].strip()}\n" | |
| if "world_scenario" in data: | |
| if data["world_scenario"].strip() not in self.context: | |
| self.context += f"Scenario: {data['world_scenario'].strip()}\n" | |
| if "scenario" in data: | |
| if data["scenario"].strip() not in self.context: | |
| self.context += f"Scenario: {data['scenario'].strip()}\n" | |
| if "personality" in data: | |
| if data["personality"].strip() not in self.context: | |
| self.context += f"Personality: {data['personality'].strip()}\n" | |
| if "description" in data: | |
| if data["description"].strip() not in self.context: | |
| self.context += f"Description: {data['description'].strip()}\n" | |
| # add dialogue examples | |
| if "example_dialogue" in data: | |
| self.example = f"\n{data['example_dialogue'].strip()}\n" | |
| # add character_file greeting | |
| if "char_greeting" in data: | |
| self.greeting = data["char_greeting"].strip() | |
| if "first_mes" in data: | |
| self.greeting = data["first_mes"].strip() | |
| if "greeting" in data: | |
| self.greeting = data["greeting"].strip() | |
| self.context = self._replace_context_templates(self.context) | |
| self.greeting = self._replace_context_templates(self.greeting) | |
| self.example = self._replace_context_templates(self.example) | |
| self.msg_id = [] | |
| self.text_in = [] | |
| self.name_in = [] | |
| self.history = [] | |
| except Exception as exception: | |
| print("load_char_json_file", exception) | |
| finally: | |
| return self | |
| def _replace_context_templates(self, s: str) -> str: | |
| s = s.replace("{{char}}", self.name2) | |
| s = s.replace("{{user}}", self.name1) | |
| s = s.replace("{{Char}}", self.name2) | |
| s = s.replace("{{User}}", self.name1) | |
| s = s.replace("<BOT>", self.name2) | |
| s = s.replace("<USER>", self.name1) | |
| return s | |
| def find_and_load_user_char_history(self, chat_id, history_dir_path: str): | |
| """Find and load user chat history. History files searched by file name template: | |
| chat_id + char_file + .json (new template versions) | |
| chat_id + name2 + .json (old template versions) | |
| Args: | |
| chat_id: user id | |
| history_dir_path: path to history dir | |
| Returns: | |
| True user history found and loaded, otherwise False | |
| """ | |
| chat_id = str(chat_id) | |
| user_char_history_path = f"{history_dir_path}/{str(chat_id)}{self.char_file}.json" | |
| user_char_history_old_path = f"{history_dir_path}/{str(chat_id)}{self.name2}.json" | |
| if exists(user_char_history_path): | |
| return self.load_user_history(user_char_history_path) | |
| elif exists(user_char_history_old_path): | |
| return self.load_user_history(user_char_history_old_path) | |
| return False | |
| def load_user_history(self, file_path): | |
| """load history file data to User data | |
| Args: | |
| file_path: path to history file | |
| Returns: | |
| True user history loaded, otherwise False | |
| """ | |
| try: | |
| if exists(file_path): | |
| with open(normpath(file_path), "r", encoding="utf-8") as user_file: | |
| data = user_file.read() | |
| self.from_json(data) | |
| if self.char_file == "": | |
| self.char_file = self.name2 | |
| return True | |
| except Exception as exception: | |
| print(f"load_user_history: {exception}") | |
| return False | |
| def save_user_history(self, chat_id, history_dir_path="history"): | |
| """Save two history file "user + character_file + .json" and default user history files and return their path | |
| Args: | |
| chat_id: user chat_id | |
| history_dir_path: history dir path | |
| Returns: | |
| user_char_file_path, default_user_file_path | |
| """ | |
| if self.char_file == "": | |
| self.char_file = self.name2 | |
| user_data = self.to_json() | |
| user_char_file_path = Path(f"{history_dir_path}/{chat_id}{self.char_file}.json") | |
| with user_char_file_path.open("w", encoding="utf-8") as user_file: | |
| user_file.write(user_data) | |
| default_user_file_path = Path(f"{history_dir_path}/{chat_id}.json") | |
| with default_user_file_path.open("w", encoding="utf-8") as user_file: | |
| user_file.write(user_data) | |
| return str(user_char_file_path), str(default_user_file_path) | |
| def check_flooding(self, flood_avoid_delay=5.0): | |
| """just check if passed flood_avoid_delay between last timestamp and now and renew new timestamp if True | |
| Args: | |
| flood_avoid_delay: | |
| Returns: | |
| True or False | |
| """ | |
| if time.time() - flood_avoid_delay > self.last_msg_timestamp: | |
| self.last_msg_timestamp = time.time() | |
| return True | |
| else: | |
| return False | |