| |
|
|
|
|
| from collections import deque |
| from fastai.text.all import load_learner |
| import sys |
| import re |
|
|
| |
# Conversation / generation settings consumed by the chat loop.
MAX_HISTORY_CHARS: int = 800  # joined history is cut to its last N characters
MAX_HISTORY_MESSAGES: int = 1  # number of "USER: …"/"BOT: …" history entries kept
GENERATE_TOKENS: int = 70  # forwarded to predict() as n_words
TEMPERATURE: float = 0.3  # sampling temperature forwarded to predict()
|
|
def evaluate_placeholders(text: str) -> str:
    """Replace ``{expr}`` placeholders holding simple arithmetic with their value.

    A placeholder is any non-empty ``{...}`` span without nested braces. It is
    evaluated only when its content consists solely of digits, whitespace and
    the operator characters ``+ - * /`` (exponentiation ``**`` is refused).
    Placeholders that do not qualify, or that fail to evaluate (syntax error,
    division by zero, ...), are left untouched, braces included.

    Args:
        text: Text that may contain placeholders, e.g. ``"2 + 3 = {2+3}"``.

    Returns:
        ``text`` with each evaluable placeholder replaced by ``str(result)``;
        true division yields a float, so ``{6/3}`` becomes ``"2.0"``.
    """

    def repl(match: re.Match) -> str:
        expr = match.group(1)
        # The character class alone still admits "**"; an exponent tower such
        # as {9**9**9} in generated text would stall the process, so refuse it.
        if not re.fullmatch(r"[\d\s\+\-\*\/]+", expr) or "**" in expr:
            return match.group(0)
        try:
            # SECURITY NOTE(review): eval() on model-generated text. It is only
            # safe because of the whitelist above (no names, calls, attributes
            # or parentheses can appear); builtins are removed as a backstop.
            return str(eval(expr, {"__builtins__": {}}, {}))
        except (SyntaxError, ArithmeticError, ValueError, RecursionError, MemoryError):
            # Best effort: malformed or non-evaluable input keeps its original text.
            return match.group(0)

    return re.sub(r"\{([^{}]+)\}", repl, text)
|
|
def remove_before_first_colon(s: str) -> str:
    """Return the part of ``s`` after the first ``"BOT :"`` marker.

    Despite the name, the split is on the literal ``"BOT :"``, not on any
    colon. If the marker is absent, ``s`` is returned unchanged.
    """
    _head, marker, tail = s.partition("BOT :")
    return tail if marker else s
def remove_before_last_colon(s: str) -> str:
    """Return the text after the last ``":"`` in ``s`` (``s`` itself if there is none)."""
    return s.rpartition(":")[2]
def remove_after_user(text):
    """Cut ``text`` immediately after the first ``"USER"``.

    The ``"USER"`` marker itself is kept. If it does not occur, ``text`` is
    returned unchanged.
    """
    before, marker, _rest = text.partition("USER")
    return before + marker
def remove_after_bot(text):
    """Cut ``text`` immediately after the first ``"BOT"``.

    The ``"BOT"`` marker itself is kept. If it does not occur, ``text`` is
    returned unchanged.
    """
    before, marker, _rest = text.partition("BOT")
    return before + marker
|
|
# Ordered (old, new) substitutions applied by ``truncate``. They delete
# leftover speaker tags and remove the space before punctuation (presumably
# left by tokenization, cf. "do n't"). Order matters, e.g. ": USER" must go
# before the bare "USER" or a stray ": " would remain.
# "\u2019" is the right single quotation mark; the previous literal for it
# was mis-encoded, so those two rules never matched real apostrophes.
_ANSWER_REPLACEMENTS = (
    (": USER", ""),
    (" USER", ""),
    ("USER", ""),
    (" !", "!"),
    (" .", "."),
    (" ,", ","),
    (": BOT", ""),
    (" BOT", ""),
    ("BOT", ""),
    (" `", "`"),
    (' "', '"'),
    (" \u2019", "\u2019"),
    ("do n'", "don'"),
    ("do n\u2019", "don\u2019"),
    (" '", "'"),
    (" :", ":"),
    (" (", "("),
    (" )", ")"),
    (" ?", "?"),
    ("Open Assistant", "Bomba-1"),
)


def truncate(answer):
    """Reduce raw generated text to a single cleaned-up bot reply.

    Steps, in order:
      1. keep only the text before the first newline, ``"USER:"`` or ``"BOT:"``;
      2. drop everything up to and including the first ``"BOT :"``, if present;
      3. cut right after the first ``"USER"`` and then the first ``"BOT"``
         (the kept tag is deleted by the substitutions below);
      4. apply ``_ANSWER_REPLACEMENTS`` in order and strip outer whitespace.

    Args:
        answer: Raw generated text.

    Returns:
        The cleaned reply, possibly empty.
    """
    for sep in ("\n", "USER:", "BOT:"):
        answer = answer.split(sep, 1)[0]

    answer = remove_before_first_colon(answer)
    answer = remove_after_user(answer)
    answer = remove_after_bot(answer)

    for old, new in _ANSWER_REPLACEMENTS:
        answer = answer.replace(old, new)

    return answer.strip()
|
|
def load_models(path="model/SimpleMath.pkl"):
    """Load the exported fastai learner used by the chat loop.

    Args:
        path: Location of the exported learner. Defaults to the previously
            hard-coded ``"model/SimpleMath.pkl"``, so existing callers are
            unaffected.

    Returns:
        The loaded learner, after calling ``.eval()`` on its ``model``.
    """
    # NOTE(review): this literal looks mis-encoded (UTF-8 emoji and "…"
    # apparently decoded with a Greek codepage) — confirm the file encoding.
    print("π€ Loading modelsβ¦")
    # NOTE(review): load_learner deserializes a pickled learner (per fastai
    # docs); only pass paths to trusted files.
    chat_model = load_learner(path)
    # Presumably a torch module: eval() switches off training-only behaviour
    # such as dropout before generation.
    chat_model.model.eval()
    return chat_model
|
|
def _build_prompt(history):
    """Flatten ``history`` into one line and append the ``" BOT: "`` cue.

    The joined text is cut to its last ``MAX_HISTORY_CHARS`` characters
    (which may split a word or a speaker tag) before the cue is added.
    """
    prompt_text = " ".join(history).replace("\n", " ")
    if len(prompt_text) > MAX_HISTORY_CHARS:
        prompt_text = prompt_text[-MAX_HISTORY_CHARS:]
    return f"{prompt_text} BOT: "


def _extract_reply(generated, prompt):
    """Return the generated continuation with the echoed ``prompt`` removed.

    If ``prompt`` does not occur verbatim in ``generated``, the whole output
    is used. A continuation that starts with a ``USER:`` turn and contains a
    ``BOT:`` tag is reduced to the text after that first ``BOT:``.
    """
    _, found, continuation = generated.partition(prompt)
    raw = (continuation if found else generated).strip()
    if raw.upper().startswith("USER:") and "BOT:" in raw:
        raw = raw.split("BOT:", 1)[1].strip()
    return raw


# Markers that get a line break inserted in front of them for display.
_LINE_BREAK_MARKERS = ("-", "1)", "2)", "3)", "4)", "5)", "* ", "Final")


def _format_answer(answer):
    """Insert line breaks before list/step markers so steps print on separate lines."""
    # NOTE: "-" also matches minus signs, so "5-3" is displayed as "5\n-3".
    for marker in _LINE_BREAK_MARKERS:
        answer = answer.replace(marker, "\n" + marker)
    # "Result" only gets its own line when there is no "Final" line.
    if "Final" not in answer:
        answer = answer.replace("Result", "\nResult")
    return answer


def main():
    """Run the interactive chat loop.

    Reads a line from stdin, generates a reply with the loaded learner,
    post-processes and prints it, and records both turns in a bounded
    history. Stops on an empty line, Ctrl-C, or end of input.
    """
    chat_model = load_models()
    # maxlen discards the oldest entries automatically; each prompt sees the
    # same last MAX_HISTORY_MESSAGES entries as a manual append/popleft trim.
    history = deque(maxlen=MAX_HISTORY_MESSAGES)
    # NOTE(review): this literal looks mis-encoded (an emoji apparently decoded
    # with a Greek codepage) — confirm the file encoding.
    print("π¬ Ready! (empty line to quit)\n")

    while True:
        try:
            user = input("USER: ").strip()
            if not user:
                break

            history.append(f"USER: {user}")
            prompt = _build_prompt(history)

            generated = chat_model.predict(
                prompt,
                n_words=GENERATE_TOKENS,
                temperature=TEMPERATURE,
                min_p=0.01,
            )

            answer = truncate(_extract_reply(generated, prompt))
            answer = _format_answer(evaluate_placeholders(answer))
            print("BOT:", answer, "\n")
            history.append(f"BOT: {answer}")

        # KeyboardInterrupt: Ctrl-C, including during generation.
        # EOFError: stdin closed (Ctrl-D or end of piped input); without this
        # the script would exit with a traceback.
        except (KeyboardInterrupt, EOFError):
            break
|
|
if __name__ == "__main__":
    # Start the interactive loop only when executed as a script, not on import.
    main()