import asyncio
import json
import os
import random
import threading
import time
from threading import Event
from typing import Optional

import discord
import gradio as gr
import gradio_client as grc
from discord import Permissions
from discord.ext import commands
from discord.utils import oauth_url
from gradio_client import Client
from gradio_client.utils import QueueError

# Set once the bot is ready (or when no token is configured) so the Gradio UI,
# which waits on it below, can start.
event = Event()

DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")
HF_TOKEN = os.getenv("HF_TOKEN")

codellama_client = Client("https://huggingface-projects-codellama-13b-chat.hf.space/", HF_TOKEN)
# thread id -> id of the Discord user who started that thread
codellama_threadid_userid_dictionary = {}
# thread id -> path of the local JSON file holding that thread's chat history
codellama_threadid_conversation = {}
# thread id -> asyncio.Lock serializing follow-up generations within one thread,
# so two quick messages never read/write the same history file concurrently.
codellama_thread_locks = {}

# Replies longer than this are clipped before being posted to Discord.
MAX_REPLY_CHARS = 1300
# Once a conversation holds this many characters, it is considered finished.
MAX_CONVERSATION_CHARS = 15000
# Seconds to sleep between job-status polls (avoids a 100%-CPU busy wait).
POLL_INTERVAL_SECONDS = 0.1

intents = discord.Intents.default()
intents.message_content = True
bot = commands.Bot(command_prefix="/", intents=intents)


@bot.event
async def on_ready():
    """Log the login, sync the slash-command tree, and unblock the Gradio UI."""
    print(f"Logged in as {bot.user} (ID: {bot.user.id})")
    synced = await bot.tree.sync()
    print(f"Synced commands: {', '.join([s.name for s in synced])}.")
    event.set()
    print("------")


@bot.hybrid_command(
    name="codellama",
    description="Enter a prompt to generate code!",
)
async def codellama(ctx, prompt: str):
    """Codellama generation: start a new thread answering *prompt*."""
    try:
        await try_codellama(ctx, prompt)
    except Exception as e:
        print(f"Error: (app.py){e}")


@bot.event
async def on_message(message):
    """Checks channel and continues codellama conversation if it's the right Discord Thread"""
    try:
        if not message.author.bot:
            await continue_codellama(message)
    except Exception as e:
        print(f"Error: {e}")
    # Overriding on_message replaces commands.Bot's default handler, so
    # prefix-style invocations of the hybrid command must be dispatched here.
    await bot.process_commands(message)


def _truncate_response(response):
    """Clip *response* to MAX_REPLY_CHARS, appending a notice when it was cut."""
    if len(response) > MAX_REPLY_CHARS:
        response = response[:MAX_REPLY_CHARS] + "...\nTruncating response due to discord api limits."
    return response


def _conversation_length(conversation):
    """Return the total number of characters across every entry of every turn."""
    return sum(len(text) for turn in conversation for text in turn)


def _read_history(chat_history):
    """Load and return the conversation list stored in the JSON file *chat_history*."""
    with open(chat_history, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def _write_history(chat_history, conversation):
    """Overwrite the JSON file *chat_history* with *conversation*."""
    with open(chat_history, "w", encoding="utf-8") as json_file:
        json.dump(conversation, json_file)


def _generate_response(prompt, chat_history):
    """Submit *prompt* and the history file to the Space and return its reply.

    Blocks the calling thread until the job finishes, so it must only be
    called from an executor, never directly on the event loop.
    """
    job = codellama_client.submit(prompt, chat_history, fn_index=0)
    # Sleep between polls instead of spinning a CPU core at 100%.
    while not job.done():
        time.sleep(POLL_INTERVAL_SECONDS)
    result = job.outputs()[-1]
    with open(result, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)
    # NOTE(review): assumes the Space returns a path to a JSON chat history whose
    # last entry is a [user, bot] pair — the bot reply is its last element; verify.
    return data[-1][-1]


def codellama_initial_generation(prompt, thread):
    """Generate the first reply for a new thread (blocking; run it in an executor).

    Creates ``<thread.id>.json`` holding the single (prompt, response) turn,
    records that file for the thread, and returns the possibly-truncated reply.
    """
    chat_history = f"{thread.id}.json"
    _write_history(chat_history, [])
    response = _generate_response(prompt, chat_history)
    _write_history(chat_history, [(prompt, response)])
    codellama_threadid_conversation[thread.id] = chat_history
    return _truncate_response(response)


def codellama_continue_generation(prompt, chat_history):
    """Generate a follow-up reply (blocking; run it in an executor).

    Returns ``(response, conversation)``. ``response`` is ``None`` when the
    conversation had already reached MAX_CONVERSATION_CHARS, in which case
    nothing is submitted. Otherwise the new turn is appended to the history
    file and the possibly-truncated reply is returned.
    """
    conversation = _read_history(chat_history)
    if _conversation_length(conversation) >= MAX_CONVERSATION_CHARS:
        return None, conversation
    response = _generate_response(prompt, chat_history)
    conversation = _read_history(chat_history)
    conversation.append((prompt, response))
    _write_history(chat_history, conversation)
    return _truncate_response(response), conversation


async def try_codellama(ctx, prompt):
    """Post *prompt*, open a thread on that post, and reply in it with the first generation."""
    try:
        message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
        # Thread name is clipped to 100 characters, presumably Discord's thread-name limit.
        thread = await message.create_thread(name=prompt[:100])
        loop = asyncio.get_running_loop()
        # The generation blocks, so keep it off the event loop.
        output_code = await loop.run_in_executor(None, codellama_initial_generation, prompt, thread)
        codellama_threadid_userid_dictionary[thread.id] = ctx.author.id
        print(output_code)
        await thread.send(output_code)
    except Exception as e:
        print(f"try_codellama Error: {e}")


async def continue_codellama(message):
    """Continue the conversation of a codellama thread when its creator posts in it."""
    try:
        if message.author.bot:
            return
        thread_id = message.channel.id
        # Only threads created by /codellama, and only their creator, are handled.
        if codellama_threadid_userid_dictionary.get(thread_id) != message.author.id:
            return
        chat_history = codellama_threadid_conversation[thread_id]
        lock = codellama_thread_locks.setdefault(thread_id, asyncio.Lock())
        async with lock:
            loop = asyncio.get_running_loop()
            # Blocking job wait + file I/O run in a worker thread so the bot stays responsive.
            response, conversation = await loop.run_in_executor(
                None, codellama_continue_generation, message.content, chat_history
            )
            if response is None:
                # Conversation was already over the length limit; stay silent.
                return
            await message.reply(response)
            if _conversation_length(conversation) >= MAX_CONVERSATION_CHARS:
                await message.reply("Conversation ending due to length, feel free to start a new one!")
    except Exception as e:
        print(f"continue_codellama Error: {e}")


#---------------------------------------------------------------------------------------------------------------------
def run_bot():
    """Run the Discord bot (blocking); without a token, just unblock the UI."""
    if not DISCORD_TOKEN:
        print("DISCORD_TOKEN NOT SET")
        event.set()
    else:
        bot.run(DISCORD_TOKEN)


threading.Thread(target=run_bot).start()

event.wait()

with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Discord bot of https://huggingface.co/spaces/huggingface-projects/codellama-13b-chat
    https://discord.com/api/oauth2/authorize?client_id=1152238037355474964&permissions=326417516544&scope=bot
    """
    )

demo.launch()