| |
| |
| import os |
|
|
| |
| import json |
| import requests |
|
|
| |
| import discord |
|
|
| |
| API_URL = 'https://api-inference.huggingface.co/models/r3dhummingbird/' |
|
|
| class MyClient(discord.Client): |
| def __init__(self, model_name): |
| super().__init__() |
| self.api_endpoint = API_URL + model_name |
| |
| huggingface_token = os.environ['HUGGINGFACE_TOKEN'] |
| |
| self.request_headers = { |
| 'Authorization': 'Bearer {}'.format(huggingface_token) |
| } |
|
|
| def query(self, payload): |
| """ |
| make request to the Hugging Face model API |
| """ |
| data = json.dumps(payload) |
| response = requests.request('POST', |
| self.api_endpoint, |
| headers=self.request_headers, |
| data=data) |
| ret = json.loads(response.content.decode('utf-8')) |
| return ret |
|
|
| async def on_ready(self): |
| |
| print('Logged in as') |
| print(self.user.name) |
| print(self.user.id) |
| print('------') |
| |
| |
| self.query({'inputs': {'text': 'Hello!'}}) |
|
|
| async def on_message(self, message): |
| """ |
| this function is called whenever the bot sees a message in a channel |
| """ |
| |
| if message.author.id == self.user.id: |
| return |
|
|
| |
| payload = {'inputs': {'text': message.content}} |
|
|
| |
| |
| async with message.channel.typing(): |
| response = self.query(payload) |
| bot_response = response.get('generated_text', None) |
| |
| |
| |
| if not bot_response: |
| if 'error' in response: |
| bot_response = '`Error: {}`'.format(response['error']) |
| else: |
| bot_response = 'Hmm... something is not right.' |
|
|
| |
| await message.channel.send(bot_response) |
|
|
| def main(): |
| |
| client = MyClient('DialoGPT-medium-joshua') |
| client.run(os.environ['DISCORD_TOKEN']) |
|
|
| if __name__ == '__main__': |
| main() |