Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
import logging
|
| 4 |
-
import asyncio
|
| 5 |
from twitchio.ext import commands
|
| 6 |
-
from transformers import AutoTokenizer,
|
|
|
|
| 7 |
|
| 8 |
# Set up logging
|
| 9 |
logging.basicConfig(level=logging.INFO)
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
-
#
|
| 13 |
TWITCH_OAUTH_TOKEN = os.getenv('TWITCH_OAUTH_TOKEN')
|
| 14 |
TWITCH_CHANNEL_NAME = os.getenv('TWITCH_CHANNEL_NAME')
|
| 15 |
TWITCH_BOT_USERNAME = os.getenv('TWITCH_BOT_USERNAME')
|
| 16 |
HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
|
| 17 |
-
MAX_TOKENS = int(os.getenv('MAX_TOKENS',
|
| 18 |
-
TEMPERATURE = float(os.getenv('TEMPERATURE', 0.
|
| 19 |
|
| 20 |
# Validate environment variables
|
| 21 |
required_vars = [
|
|
@@ -29,14 +29,11 @@ missing_vars = [var for var in required_vars if not globals().get(var)]
|
|
| 29 |
if missing_vars:
|
| 30 |
raise ValueError(f"Missing environment variables: {', '.join(missing_vars)}")
|
| 31 |
|
| 32 |
-
# Initialize the Hugging Face tokenizer and model
|
| 33 |
-
model_name = "
|
| 34 |
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN)
|
| 35 |
-
model =
|
| 36 |
-
|
| 37 |
-
# Ensure the model runs on CPU
|
| 38 |
-
device = 'cpu'
|
| 39 |
-
model.to(device)
|
| 40 |
|
| 41 |
# List of house music hooks to drop randomly
|
| 42 |
HOUSE_MUSIC_HOOKS = [
|
|
@@ -139,47 +136,53 @@ HOUSE_MUSIC_HOOKS = [
|
|
| 139 |
"The music's got us moving, can't stop dancing!",
|
| 140 |
]
|
| 141 |
|
| 142 |
-
#
|
| 143 |
-
|
| 144 |
-
"""Generates a response using the FLAN-T5 model."""
|
| 145 |
-
# Adjusted prompt for better guidance
|
| 146 |
-
guided_prompt = (
|
| 147 |
-
f"You are a friendly and entertaining chatbot with the personality of an old-school raver who loves house music, good vibes, and funky beats. "
|
| 148 |
-
f"Respond to the user's message in a groovy, laid-back, and full-of-love style.\n\n"
|
| 149 |
-
f"User: {prompt}\nBot:"
|
| 150 |
-
)
|
| 151 |
|
|
|
|
|
|
|
| 152 |
try:
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
temperature=TEMPERATURE,
|
| 160 |
do_sample=True,
|
| 161 |
-
top_p=0.
|
| 162 |
top_k=50,
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
no_repeat_ngram_size=2,
|
| 166 |
)
|
| 167 |
-
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 168 |
|
| 169 |
-
#
|
| 170 |
-
|
|
|
|
| 171 |
|
| 172 |
-
#
|
| 173 |
-
|
| 174 |
-
if not response:
|
| 175 |
-
response = response_text.strip()
|
| 176 |
|
| 177 |
-
# Randomly include a house hook (30% chance)
|
| 178 |
if random.random() < 0.3:
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
logger.info(f"Generated response: {response}")
|
| 182 |
-
return response
|
| 183 |
except Exception as e:
|
| 184 |
logger.error(f"Error generating response: {e}")
|
| 185 |
return "Sorry, I'm too hyped to respond right now!"
|
|
@@ -202,23 +205,17 @@ class TwitchChatBot(commands.Bot):
|
|
| 202 |
|
| 203 |
async def event_message(self, message):
|
| 204 |
"""Event handler when a message is received in chat."""
|
| 205 |
-
|
| 206 |
-
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
-
if message.author:
|
| 210 |
-
logger.info(f"Message received from {message.author.name}: {message.content}")
|
| 211 |
-
else:
|
| 212 |
-
logger.info(f"Message received: {message.content}")
|
| 213 |
|
| 214 |
# Generate a response
|
| 215 |
-
response = await generate_response(message.content)
|
| 216 |
|
| 217 |
# Send the response back to the Twitch chat
|
| 218 |
-
|
| 219 |
-
await message.channel.send(f"@{message.author.name} {response}")
|
| 220 |
-
else:
|
| 221 |
-
await message.channel.send(response)
|
| 222 |
|
| 223 |
# Initialize and run the bot
|
| 224 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
import logging
|
|
|
|
| 4 |
from twitchio.ext import commands
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
import torch
|
| 7 |
|
| 8 |
# Set up logging
|
| 9 |
logging.basicConfig(level=logging.INFO)
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
+
# Credentials and settings
|
| 13 |
TWITCH_OAUTH_TOKEN = os.getenv('TWITCH_OAUTH_TOKEN')
|
| 14 |
TWITCH_CHANNEL_NAME = os.getenv('TWITCH_CHANNEL_NAME')
|
| 15 |
TWITCH_BOT_USERNAME = os.getenv('TWITCH_BOT_USERNAME')
|
| 16 |
HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
|
| 17 |
+
MAX_TOKENS = int(os.getenv('MAX_TOKENS', 100))
|
| 18 |
+
TEMPERATURE = float(os.getenv('TEMPERATURE', 0.7))
|
| 19 |
|
| 20 |
# Validate environment variables
|
| 21 |
required_vars = [
|
|
|
|
| 29 |
if missing_vars:
|
| 30 |
raise ValueError(f"Missing environment variables: {', '.join(missing_vars)}")
|
| 31 |
|
| 32 |
+
# Initialize the Hugging Face tokenizer and model for DialoGPT
|
| 33 |
+
model_name = "microsoft/DialoGPT-small"
|
| 34 |
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN)
|
| 35 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_API_TOKEN)
|
| 36 |
+
model.to('cpu')
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# List of house music hooks to drop randomly
|
| 39 |
HOUSE_MUSIC_HOOKS = [
|
|
|
|
| 136 |
"The music's got us moving, can't stop dancing!",
|
| 137 |
]
|
| 138 |
|
| 139 |
+
# Initialize chat history for users
|
| 140 |
+
chat_histories = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
+
async def generate_response(user_id, user_message):
|
| 143 |
+
"""Generates a response using the DialoGPT model."""
|
| 144 |
try:
|
| 145 |
+
# Retrieve or initialize the chat history for the user
|
| 146 |
+
if user_id in chat_histories:
|
| 147 |
+
chat_history_ids = chat_histories[user_id]
|
| 148 |
+
else:
|
| 149 |
+
chat_history_ids = None
|
| 150 |
+
|
| 151 |
+
# Encode the user message and append the EOS token
|
| 152 |
+
new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt').to('cpu')
|
| 153 |
+
|
| 154 |
+
# Concatenate new user input with chat history (if it exists)
|
| 155 |
+
if chat_history_ids is not None:
|
| 156 |
+
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
|
| 157 |
+
else:
|
| 158 |
+
bot_input_ids = new_user_input_ids
|
| 159 |
|
| 160 |
+
# Generate a response
|
| 161 |
+
output_ids = model.generate(
|
| 162 |
+
bot_input_ids,
|
| 163 |
+
max_length=bot_input_ids.shape[-1] + MAX_TOKENS,
|
| 164 |
temperature=TEMPERATURE,
|
| 165 |
do_sample=True,
|
| 166 |
+
top_p=0.95,
|
| 167 |
top_k=50,
|
| 168 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 169 |
+
no_repeat_ngram_size=3,
|
|
|
|
| 170 |
)
|
|
|
|
| 171 |
|
| 172 |
+
# Extract the new response
|
| 173 |
+
response_ids = output_ids[:, bot_input_ids.shape[-1]:]
|
| 174 |
+
response_text = tokenizer.decode(response_ids[0], skip_special_tokens=True)
|
| 175 |
|
| 176 |
+
# Update the chat history
|
| 177 |
+
chat_histories[user_id] = output_ids[:, -1000:] # Keep last 1000 tokens to limit history size
|
|
|
|
|
|
|
| 178 |
|
| 179 |
+
# Randomly include a house music hook (30% chance)
|
| 180 |
if random.random() < 0.3:
|
| 181 |
+
response_text = f"{random.choice(HOUSE_MUSIC_HOOKS)} {response_text}"
|
| 182 |
+
|
| 183 |
+
logger.info(f"Generated response: {response_text}")
|
| 184 |
+
return response_text
|
| 185 |
|
|
|
|
|
|
|
| 186 |
except Exception as e:
|
| 187 |
logger.error(f"Error generating response: {e}")
|
| 188 |
return "Sorry, I'm too hyped to respond right now!"
|
|
|
|
| 205 |
|
| 206 |
async def event_message(self, message):
|
| 207 |
"""Event handler when a message is received in chat."""
|
| 208 |
+
# Ignore messages sent by the bot itself
|
| 209 |
+
if message.echo:
|
| 210 |
+
return
|
| 211 |
|
| 212 |
+
logger.info(f"Message received from {message.author.name}: {message.content}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
# Generate a response
|
| 215 |
+
response = await generate_response(message.author.id, message.content)
|
| 216 |
|
| 217 |
# Send the response back to the Twitch chat
|
| 218 |
+
await message.channel.send(f"@{message.author.name} {response}")
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
# Initialize and run the bot
|
| 221 |
if __name__ == "__main__":
|