jacksonstrut committed on
Commit
9e9401b
·
verified ·
1 Parent(s): 5ade708

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -53
app.py CHANGED
@@ -1,21 +1,21 @@
1
  import os
2
  import random
3
  import logging
4
- import asyncio
5
  from twitchio.ext import commands
6
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
- # Get Twitch credentials and Hugging Face API token from environment variables
13
  TWITCH_OAUTH_TOKEN = os.getenv('TWITCH_OAUTH_TOKEN')
14
  TWITCH_CHANNEL_NAME = os.getenv('TWITCH_CHANNEL_NAME')
15
  TWITCH_BOT_USERNAME = os.getenv('TWITCH_BOT_USERNAME')
16
  HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
17
- MAX_TOKENS = int(os.getenv('MAX_TOKENS', 50))
18
- TEMPERATURE = float(os.getenv('TEMPERATURE', 0.9))
19
 
20
  # Validate environment variables
21
  required_vars = [
@@ -29,14 +29,11 @@ missing_vars = [var for var in required_vars if not globals().get(var)]
29
  if missing_vars:
30
  raise ValueError(f"Missing environment variables: {', '.join(missing_vars)}")
31
 
32
- # Initialize the Hugging Face tokenizer and model
33
- model_name = "google/flan-t5-small" # or "t5-small" for an even lighter model
34
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN)
35
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN)
36
-
37
- # Ensure the model runs on CPU
38
- device = 'cpu'
39
- model.to(device)
40
 
41
  # List of house music hooks to drop randomly
42
  HOUSE_MUSIC_HOOKS = [
@@ -139,47 +136,53 @@ HOUSE_MUSIC_HOOKS = [
139
  "The music's got us moving, can't stop dancing!",
140
  ]
141
 
142
- # Define the response function to generate a reply
143
- async def generate_response(prompt):
144
- """Generates a response using the FLAN-T5 model."""
145
- # Adjusted prompt for better guidance
146
- guided_prompt = (
147
- f"You are a friendly and entertaining chatbot with the personality of an old-school raver who loves house music, good vibes, and funky beats. "
148
- f"Respond to the user's message in a groovy, laid-back, and full-of-love style.\n\n"
149
- f"User: {prompt}\nBot:"
150
- )
151
 
 
 
152
  try:
153
- inputs = tokenizer(guided_prompt, return_tensors="pt", max_length=512, truncation=True)
154
- inputs = inputs.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- outputs = model.generate(
157
- **inputs,
158
- max_length=MAX_TOKENS + inputs['input_ids'].shape[1],
 
159
  temperature=TEMPERATURE,
160
  do_sample=True,
161
- top_p=0.9,
162
  top_k=50,
163
- num_beams=1,
164
- num_return_sequences=1,
165
- no_repeat_ngram_size=2,
166
  )
167
- response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
168
 
169
- # Log the full model output for debugging
170
- logger.info(f"Full model output: {response_text}")
 
171
 
172
- # Extract the bot's response
173
- response = response_text.split("Bot:", 1)[-1].strip()
174
- if not response:
175
- response = response_text.strip()
176
 
177
- # Randomly include a house hook (30% chance)
178
  if random.random() < 0.3:
179
- response = f"{random.choice(HOUSE_MUSIC_HOOKS)} {response}"
 
 
 
180
 
181
- logger.info(f"Generated response: {response}")
182
- return response
183
  except Exception as e:
184
  logger.error(f"Error generating response: {e}")
185
  return "Sorry, I'm too hyped to respond right now!"
@@ -202,23 +205,17 @@ class TwitchChatBot(commands.Bot):
202
 
203
  async def event_message(self, message):
204
  """Event handler when a message is received in chat."""
205
- if message.author and message.author.name.lower() == self.nick.lower():
206
- return # Ignore the bot's own messages
 
207
 
208
- # Safely log the message
209
- if message.author:
210
- logger.info(f"Message received from {message.author.name}: {message.content}")
211
- else:
212
- logger.info(f"Message received: {message.content}")
213
 
214
  # Generate a response
215
- response = await generate_response(message.content)
216
 
217
  # Send the response back to the Twitch chat
218
- if message.author:
219
- await message.channel.send(f"@{message.author.name} {response}")
220
- else:
221
- await message.channel.send(response)
222
 
223
  # Initialize and run the bot
224
  if __name__ == "__main__":
 
1
  import os
2
  import random
3
  import logging
 
4
  from twitchio.ext import commands
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import torch
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
+ # Credentials and settings
13
  TWITCH_OAUTH_TOKEN = os.getenv('TWITCH_OAUTH_TOKEN')
14
  TWITCH_CHANNEL_NAME = os.getenv('TWITCH_CHANNEL_NAME')
15
  TWITCH_BOT_USERNAME = os.getenv('TWITCH_BOT_USERNAME')
16
  HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
17
+ MAX_TOKENS = int(os.getenv('MAX_TOKENS', 100))
18
+ TEMPERATURE = float(os.getenv('TEMPERATURE', 0.7))
19
 
20
  # Validate environment variables
21
  required_vars = [
 
29
  if missing_vars:
30
  raise ValueError(f"Missing environment variables: {', '.join(missing_vars)}")
31
 
32
# Load the DialoGPT tokenizer and model from the Hugging Face Hub.
# Inference is pinned to CPU (nn.Module.to returns the module itself).
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN).to('cpu')
 
 
 
37
 
38
  # List of house music hooks to drop randomly
39
  HOUSE_MUSIC_HOOKS = [
 
136
  "The music's got us moving, can't stop dancing!",
137
  ]
138
 
139
# Per-user conversation history (token-id tensors), keyed by Twitch user id.
chat_histories = dict()
 
 
 
 
 
 
 
141
 
142
async def generate_response(user_id, user_message):
    """Generate a chat reply with DialoGPT, keeping per-user conversation history.

    Args:
        user_id: Stable identifier of the chatter; keys the history cache.
        user_message: Raw text of the incoming chat message.

    Returns:
        The generated reply text — possibly prefixed with a random house-music
        hook — or a canned apology string if generation fails for any reason.
    """
    import asyncio  # local import so the module's top-level import block is untouched

    try:
        # Prior conversation tensor for this user, if any.
        chat_history_ids = chat_histories.get(user_id)

        # Encode the new user turn, terminated by EOS as DialoGPT expects.
        new_user_input_ids = tokenizer.encode(
            user_message + tokenizer.eos_token, return_tensors='pt'
        ).to('cpu')

        # Append the new turn to the running history (if there is one).
        if chat_history_ids is not None:
            bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
        else:
            bot_input_ids = new_user_input_ids

        # model.generate is CPU-bound and blocking; run it in a worker thread so
        # the asyncio event loop (and the Twitch IRC connection) stays responsive.
        output_ids = await asyncio.to_thread(
            model.generate,
            bot_input_ids,
            max_length=bot_input_ids.shape[-1] + MAX_TOKENS,
            temperature=TEMPERATURE,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,  # DialoGPT has no pad token; reuse EOS
            no_repeat_ngram_size=3,
        )

        # Only the tokens generated after the prompt are the bot's reply.
        response_ids = output_ids[:, bot_input_ids.shape[-1]:]
        response_text = tokenizer.decode(response_ids[0], skip_special_tokens=True)

        # Persist the history, truncated to bound per-user memory growth.
        chat_histories[user_id] = output_ids[:, -1000:]  # keep last 1000 tokens

        if not response_text.strip():
            # DialoGPT can emit an empty/whitespace-only string; never send a
            # blank reply — fall back to a house-music hook instead.
            response_text = random.choice(HOUSE_MUSIC_HOOKS)
        elif random.random() < 0.3:
            # Otherwise, randomly prepend a house-music hook (30% chance).
            response_text = f"{random.choice(HOUSE_MUSIC_HOOKS)} {response_text}"

        logger.info(f"Generated response: {response_text}")
        return response_text

    except Exception as e:
        # logger.exception keeps the traceback, which plain error() would drop.
        logger.exception(f"Error generating response: {e}")
        return "Sorry, I'm too hyped to respond right now!"
 
205
 
206
  async def event_message(self, message):
207
  """Event handler when a message is received in chat."""
208
+ # Ignore messages sent by the bot itself
209
+ if message.echo:
210
+ return
211
 
212
+ logger.info(f"Message received from {message.author.name}: {message.content}")
 
 
 
 
213
 
214
  # Generate a response
215
+ response = await generate_response(message.author.id, message.content)
216
 
217
  # Send the response back to the Twitch chat
218
+ await message.channel.send(f"@{message.author.name} {response}")
 
 
 
219
 
220
  # Initialize and run the bot
221
  if __name__ == "__main__":