import asyncio
import json
import os
import time

from gradio_client import Client
| |
|
# Hugging Face token used to authenticate against the Space (may be None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Gradio client for the hosted CodeLlama-13B chat Space.
codellama = Client("https://huggingface-projects-codellama-13b-chat.hf.space/", HF_TOKEN)

# Discord IDs: the bot's own user id, and the channel dedicated to CodeLlama chats.
BOT_USER_ID = 1102236653545861151
CODELLAMA_CHANNEL_ID = 1147210106321256508

# thread.id -> author.id of the user who started that thread (used as a safety check).
codellama_threadid_userid_dictionary = {}
# thread.id -> path of the JSON file holding that thread's chat history.
codellama_threadid_conversation = {}
| |
|
| |
|
def codellama_initial_generation(prompt, thread):
    """Run the first CodeLlama generation for a newly created thread.

    This function blocks; it is intended to be run via
    ``loop.run_in_executor`` ("job.submit inside of run_in_executor =
    more consistent bot behavior") so the discord event loop stays
    responsive while the Space generates.

    Args:
        prompt: The user's first message text.
        thread: The discord thread object; ``thread.id`` names the
            per-thread chat-history file.

    Returns:
        The model response, truncated to respect discord's message limits.
    """
    global codellama_threadid_conversation

    # Each thread gets its own JSON chat-history file, named after the thread id.
    chat_history = f"{thread.id}.json"
    conversation = []
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    job = codellama.submit(prompt, chat_history, fn_index=0)

    # Poll with a short sleep instead of a busy-wait so this executor
    # thread doesn't spin a CPU core while the Space is generating.
    while not job.done():
        time.sleep(0.5)

    # The job streams intermediate output files; the last one holds the
    # complete conversation, and its final pair's second element is the reply.
    result = job.outputs()[-1]
    with open(result, "r") as json_file:
        data = json.load(json_file)
    response = data[-1][-1]

    # Persist the (prompt, response) pair so follow-up turns can continue the chat.
    conversation.append((prompt, response))
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    codellama_threadid_conversation[thread.id] = chat_history
    if len(response) > 1300:
        response = response[:1300] + "...\nTruncating response due to discord api limits."
    return response
| |
|
| |
|
async def try_codellama(ctx, prompt):
    """Generates text based on a given prompt"""
    try:
        global codellama_threadid_userid_dictionary
        global codellama_threadid_conversation

        # Guard clauses: ignore the bot's own messages and anything outside
        # the dedicated CodeLlama channel.
        if ctx.author.id == BOT_USER_ID:
            return
        if ctx.channel.id != CODELLAMA_CHANNEL_ID:
            return

        message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
        # Thread names are length-capped, so keep at most the first 99 characters.
        small_prompt = prompt[:99] if len(prompt) > 99 else prompt
        thread = await message.create_thread(name=small_prompt, auto_archive_duration=60)

        # Run the blocking gradio job off the event loop.
        loop = asyncio.get_running_loop()
        output_code = await loop.run_in_executor(None, codellama_initial_generation, prompt, thread)
        codellama_threadid_userid_dictionary[thread.id] = ctx.author.id

        print(output_code)
        await thread.send(output_code)
    except Exception as e:
        print(f"try_codellama Error: {e}")
        await ctx.send(f"Error: {e} <@811235357663297546> (try_codellama error)")
| |
|
| |
|
def _codellama_followup_generation(prompt, chat_history):
    """Blocking helper: run one follow-up CodeLlama generation.

    Loads the stored conversation, runs the Space, appends the new
    (prompt, response) pair back to the ``chat_history`` JSON file.

    Returns:
        A ``(response, total_characters)`` tuple, where ``total_characters``
        counts every stored prompt/response character *including* the new
        pair. Returns ``(None, total_characters)`` without generating when
        the conversation is already at the length cutoff.
    """
    with open(chat_history, "r") as json_file:
        conversation = json.load(json_file)

    # Rough size guard: total characters across all stored prompts/responses.
    total_characters = sum(len(string) for item in conversation for string in item)
    if total_characters >= 15000:
        return None, total_characters

    if os.environ.get("TEST_ENV") == "True":
        print("Running codellama.submit")
    job = codellama.submit(prompt, chat_history, fn_index=0)
    # Poll with a short sleep instead of busy-waiting a CPU core.
    while not job.done():
        time.sleep(0.5)
    if os.environ.get("TEST_ENV") == "True":
        print("Continue_codellama job done")

    # The last streamed output file holds the full conversation; the final
    # pair's second element is the model's reply.
    result = job.outputs()[-1]
    with open(result, "r") as json_file:
        data = json.load(json_file)
    response = data[-1][-1]

    # Re-read then persist the updated conversation.
    with open(chat_history, "r") as json_file:
        conversation = json.load(json_file)
    conversation.append((prompt, response))
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    if os.environ.get("TEST_ENV") == "True":
        print(prompt)
        print(response)
        print(conversation)
        print(chat_history)

    total_characters = sum(len(string) for item in conversation for string in item)
    return response, total_characters


async def continue_codellama(message):
    """Continues a given conversation based on chat_history"""
    try:
        # Safety checks: ignore bots, only react inside tracked threads, and
        # only to the user who started that thread.
        if message.author.bot:
            return
        global codellama_threadid_userid_dictionary
        if message.channel.id not in codellama_threadid_userid_dictionary:
            return
        if codellama_threadid_userid_dictionary[message.channel.id] != message.author.id:
            return
        print("Safetychecks passed for continue_codellama")
        global codellama_threadid_conversation

        prompt = message.content
        chat_history = codellama_threadid_conversation[message.channel.id]

        # Run the blocking gradio job off the event loop (same pattern as
        # try_codellama). Previously this submitted and busy-waited directly
        # inside the coroutine, stalling the whole discord event loop.
        loop = asyncio.get_running_loop()
        response, total_characters = await loop.run_in_executor(
            None, _codellama_followup_generation, prompt, chat_history
        )
        if response is None:
            # Conversation already past the cutoff; original behavior is to stay silent.
            return

        codellama_threadid_conversation[message.channel.id] = chat_history
        if len(response) > 1300:
            response = response[:1300] + "...\nTruncating response due to discord api limits."

        await message.reply(response)

        # If this exchange pushed the stored history over the cutoff, warn
        # the user that the conversation is ending.
        if total_characters >= 15000:
            await message.reply("Conversation ending due to length, feel free to start a new one!")
    except Exception as e:
        print(f"continue_codellama Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_codellama error)")
| |
|