File size: 14,222 Bytes
fcaa164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, List, Optional

import discord
import httpx
from fastapi import FastAPI

from camel.bots.discord.discord_installation import DiscordInstallation
from camel.logger import get_logger
from camel.utils import api_keys_required, dependencies_required

from .discord_store import DiscordBaseInstallationStore

if TYPE_CHECKING:
    from discord import Message

logger = get_logger(__name__)

TOKEN_URL = "https://discord.com/api/oauth2/token"
USER_URL = "https://discord.com/api/users/@me"


class DiscordApp:
    r"""A class representing a Discord app that uses the `discord.py` library
    to interact with Discord servers.

    This bot can respond to messages in specific channels and only reacts to
    messages that mention the bot.

    Attributes:
        channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
            provided, the bot will only respond to messages in these channels.
        token (Optional[str]): The Discord bot token used for authentication.
    """

    @dependencies_required('discord')
    @api_keys_required(
        [
            ("token", "DISCORD_BOT_TOKEN"),
        ]
    )
    def __init__(
        self,
        channel_ids: Optional[List[int]] = None,
        token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri: Optional[str] = None,
        installation_store: Optional[DiscordBaseInstallationStore] = None,
        intents: Optional[discord.Intents] = None,
    ) -> None:
        r"""Initialize the DiscordApp instance by setting up the Discord client
        and event handlers.

        Args:
            channel_ids (Optional[List[int]]): A list of allowed channel IDs.
                The bot will only respond to messages in these channels if
                provided. (default: :obj:`None`)
            token (Optional[str]): The Discord bot token for authentication.
                If not provided, the token will be retrieved from the
                environment variable `DISCORD_TOKEN`. (default: :obj:`None`)
            client_id (str, optional): The client ID for Discord OAuth.
                (default: :obj:`None`)
            client_secret (Optional[str]): The client secret for Discord OAuth.
                (default: :obj:`None`)
            redirect_uri (str): The redirect URI for OAuth callbacks.
                (default: :obj:`None`)
            installation_store (DiscordAsyncInstallationStore): The database
                stores all information of all installations.
                (default: :obj:`None`)
            intents (discord.Intents): The Discord intents of this app.
                (default: :obj:`None`)

        Raises:
            ValueError: If the `DISCORD_BOT_TOKEN` is not found in environment
                variables.
        """
        self.token = token or os.getenv("DISCORD_BOT_TOKEN")
        self.channel_ids = channel_ids
        self.installation_store = installation_store

        if not intents:
            intents = discord.Intents.all()
            intents.message_content = True
            intents.guilds = True

        self._client = discord.Client(intents=intents)

        # Register event handlers
        self._client.event(self.on_ready)
        self._client.event(self.on_message)

        # OAuth flow
        self.client_id = client_id or os.getenv("DISCORD_CLIENT_ID")
        self.client_secret = client_secret or os.getenv(
            "DISCORD_CLIENT_SECRET"
        )
        self.redirect_uri = redirect_uri

        self.oauth_flow = bool(
            self.client_id
            and self.client_secret
            and self.redirect_uri
            and self.installation_store
        )

        self.app = FastAPI()

    async def start(self):
        r"""Asynchronously start the Discord bot using its token.

        This method starts the bot and logs into Discord asynchronously using
        the provided token. It should be awaited when used in an async
        environment.
        """
        await self._client.start(self.token)

    def run(self) -> None:
        r"""Start the Discord bot using its token.

        This method starts the bot and logs into Discord synchronously using
        the provided token. It blocks execution and keeps the bot running.
        """
        self._client.run(self.token)  # type: ignore[arg-type]

    async def exchange_code_for_token_response(
        self, code: str
    ) -> Optional[str]:
        r"""Exchange the authorization code for an access token.

        Args:
            code (str): The authorization code received from Discord after
                user authorization.

        Returns:
            Optional[str]: The access token if successful, otherwise None.

        Raises:
            ValueError: If OAuth configuration is incomplete or invalid.
            httpx.RequestError: If there is a network issue during the request.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    TOKEN_URL, data=data, headers=headers
                )
                if response.status_code != 200:
                    logger.error(f"Failed to exchange code: {response.text}")
                    return None
                response_data = response.json()

                return response_data
        except (httpx.RequestError, ValueError) as e:
            logger.error(f"Error during token fetch: {e}")
            return None

    async def get_user_info(self, access_token: str) -> Optional[dict]:
        r"""Retrieve user information using the access token.

        Args:
            access_token (str): The access token received from Discord.

        Returns:
            dict: The user information retrieved from Discord.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        headers = {"Authorization": f"Bearer {access_token}"}
        async with httpx.AsyncClient() as client:
            user_response = await client.get(USER_URL, headers=headers)
            return user_response.json()

    async def refresh_access_token(self, refresh_token: str) -> Optional[str]:
        r"""Refresh the access token using a refresh token.

        Args:
            refresh_token (str): The refresh token issued by Discord that
                can be used to obtain a new access token.

        Returns:
            Optional[str]: The new access token if successful, otherwise None.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        async with httpx.AsyncClient() as client:
            response = await client.post(TOKEN_URL, data=data, headers=headers)
            if response.status_code != 200:
                logger.error(f"Failed to refresh token: {response.text}")
                return None
            response_data = response.json()
            return response_data.get("access_token")

    async def get_valid_access_token(self, guild_id: str) -> Optional[str]:
        r"""Retrieve a valid access token for the specified guild.

        This method attempts to retrieve an access token for a specific guild.
        If the current access token is expired, it will refresh the token using
        the refresh token.

        Args:
            guild_id (str): The ID of the guild to retrieve the access
                token for.

        Returns:
            Optional[str]: The valid access token if successful,
                otherwise None.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        installation = await self.installation_store.find_by_guild(
            guild_id=guild_id
        )
        if not installation:
            logger.error(f"No installation found for guild: {guild_id}")
            return None

        if (
            installation.token_expires_at
            and datetime.now() >= installation.token_expires_at
        ):
            logger.info(
                f"Access token expired for guild: {guild_id}, "
                f"refreshing token..."
            )
            new_access_token = await self.refresh_access_token(
                installation.refresh_token
            )
            if new_access_token:
                installation.access_token = new_access_token
                installation.token_expires_at = datetime.now() + timedelta(
                    seconds=3600
                )
                await self.installation_store.save(installation)
                return new_access_token
            else:
                logger.error(
                    f"Failed to refresh access token for guild: {guild_id}"
                )
                return None

        return installation.access_token

    async def save_installation(
        self,
        guild_id: str,
        access_token: str,
        refresh_token: str,
        expires_in: int,
    ):
        r"""Save the installation information for a given guild.

        Args:
            guild_id (str): The ID of the guild where the bot is installed.
            access_token (str): The access token for the guild.
            refresh_token (str): The refresh token for the guild.
            expires_in: (int): The expiration time of the
                access token.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        expires_at = datetime.now() + timedelta(seconds=expires_in)
        installation = DiscordInstallation(
            guild_id=guild_id,
            access_token=access_token,
            refresh_token=refresh_token,
            installed_at=datetime.now(),
            token_expires_at=expires_at,
        )
        await self.installation_store.save(installation)
        logger.info(f"Installation saved for guild: {guild_id}")

    async def remove_installation(self, guild: discord.Guild):
        r"""Remove the installation for a given guild.

        Args:
            guild (discord.Guild): The guild from which the bot is
                being removed.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        await self.installation_store.delete(guild_id=str(guild.id))
        print(f"Bot removed from guild: {guild.id}")

    async def on_ready(self) -> None:
        r"""Event handler that is called when the bot has successfully
        connected to the Discord server.

        When the bot is ready and logged into Discord, it prints a message
        displaying the bot's username.
        """
        logger.info(f'We have logged in as {self._client.user}')

    async def on_message(self, message: 'Message') -> None:
        r"""Event handler for processing incoming messages.

        This method is called whenever a new message is received by the bot. It
        will ignore messages sent by the bot itself, only respond to messages
        in allowed channels (if specified), and only to messages that mention
        the bot.

        Args:
            message (discord.Message): The message object received from
                Discord.
        """
        # If the message author is the bot itself,
        # do not respond to this message
        if message.author == self._client.user:
            return

        # If allowed channel IDs are provided,
        # only respond to messages in those channels
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return

        # Only respond to messages that mention the bot
        if not self._client.user or not self._client.user.mentioned_in(
            message
        ):
            return

        logger.info(f"Received message: {message.content}")

    @property
    def client(self):
        return self._client