File size: 6,541 Bytes
94edee7
 
 
71cb2d7
 
 
a8332cf
94edee7
 
 
71cb2d7
 
 
94edee7
32683f5
71cb2d7
 
94edee7
bca6a8e
67f190a
bca6a8e
94edee7
9091fe3
94edee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4534e98
 
94edee7
 
 
 
 
4534e98
94edee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9091fe3
94edee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9091fe3
94edee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91c09af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b973f74
91c09af
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import asyncio
import json
import os
import random
import threading
import time
from threading import Event
from typing import Optional

import discord
import gradio as gr
from discord import Permissions
from discord.ext import commands
from discord.utils import oauth_url

from gradio_client import Client
import gradio_client as grc
from gradio_client.utils import QueueError

event = Event()
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")

HF_TOKEN = os.getenv("HF_TOKEN")
codellama_client = Client("https://huggingface-projects-codellama-13b-chat.hf.space/", HF_TOKEN)





codellama_threadid_userid_dictionary = {}
codellama_threadid_conversation = {}








intents = discord.Intents.default()
intents.message_content = True
bot = commands.Bot(command_prefix="/", intents=intents)

@bot.event
async def on_ready():
    print(f"Logged in as {bot.user} (ID: {bot.user.id})")
    synced = await bot.tree.sync()
    print(f"Synced commands: {', '.join([s.name for s in synced])}.")
    event.set()
    print("------")



@bot.hybrid_command(
    name="codellama",
    description="Enter a prompt to generate code!",
)
async def codellama(ctx, prompt: str):
    """Audioldm2 generation"""
    try:
        await try_codellama(ctx, prompt)
    except Exception as e:
        print(f"Error: (app.py){e}")



@bot.event
async def on_message(message):
    """Checks channel and continues codellama conversation if it's the right Discord Thread"""
    try:
        if not message.author.bot:
            await continue_codellama(message)
    except Exception as e:
        print(f"Error: {e}")








async def try_codellama(ctx, prompt):
    """Generates text based on a given prompt"""
    try:
        global codellama_threadid_userid_dictionary
        global codellama_threadid_conversation

        message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
        thread = await message.create_thread(name=prompt[:100])


        
        loop = asyncio.get_running_loop()
        output_code = await loop.run_in_executor(None, codellama_initial_generation, prompt, thread)
        codellama_threadid_userid_dictionary[thread.id] = ctx.author.id

        print(output_code)
        await thread.send(output_code)
    except Exception as e:
        print(f"try_codellama Error: {e}")





def codellama_initial_generation(prompt, thread):
    """job.submit inside of run_in_executor = more consistent bot behavior"""
    global codellama_threadid_conversation

    chat_history = f"{thread.id}.json"
    conversation = []
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    job = codellama_client.submit(prompt, chat_history, fn_index=0)

    while job.done() is False:
        pass
    else:
        result = job.outputs()[-1]
        with open(result, "r") as json_file:
            data = json.load(json_file)
        response = data[-1][-1]
        conversation.append((prompt, response))
        with open(chat_history, "w") as json_file:
            json.dump(conversation, json_file)

        codellama_threadid_conversation[thread.id] = chat_history
        if len(response) > 1300:
            response = response[:1300] + "...\nTruncating response due to discord api limits."
        return response























async def continue_codellama(message):
    """Continues a given conversation based on chat_history"""
    try:
        if not message.author.bot:
            global codellama_threadid_userid_dictionary  # tracks userid-thread existence
            if message.channel.id in codellama_threadid_userid_dictionary:  # is this a valid thread?
                if codellama_threadid_userid_dictionary[message.channel.id] == message.author.id:
                    global codellama_threadid_conversation

                    prompt = message.content
                    chat_history = codellama_threadid_conversation[message.channel.id]

                    # Check to see if conversation is ongoing or ended (>15000 characters)
                    with open(chat_history, "r") as json_file:
                        conversation = json.load(json_file)
                    total_characters = 0
                    for item in conversation:
                        for string in item:
                            total_characters += len(string)

                    if total_characters < 15000:
                        job = codellama_client.submit(prompt, chat_history, fn_index=0)
                        while job.done() is False:
                            pass
                        else:
                            result = job.outputs()[-1]
                            with open(result, "r") as json_file:
                                data = json.load(json_file)
                            response = data[-1][-1]
                            with open(chat_history, "r") as json_file:
                                conversation = json.load(json_file)
                            conversation.append((prompt, response))
                            with open(chat_history, "w") as json_file:
                                json.dump(conversation, json_file)
                            codellama_threadid_conversation[message.channel.id] = chat_history
                            
                            if len(response) > 1300:
                                response = response[:1300] + "...\nTruncating response due to discord api limits."

                            await message.reply(response)

                            total_characters = 0
                            for item in conversation:
                                for string in item:
                                    total_characters += len(string)

                            if total_characters >= 15000:
                                await message.reply("Conversation ending due to length, feel free to start a new one!")

    except Exception as e:
        print(f"continue_codellama Error: {e}")
#---------------------------------------------------------------------------------------------------------------------
def run_bot():
    if not DISCORD_TOKEN:
        print("DISCORD_TOKEN NOT SET")
        event.set()
    else:
        bot.run(DISCORD_TOKEN)


threading.Thread(target=run_bot).start()

event.wait()

with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Discord bot of https://huggingface.co/spaces/facebook/MusicGen
    https://discord.com/api/oauth2/authorize?client_id=1152238037355474964&permissions=326417516544&scope=bot
    """
    )

demo.launch()