# Therapy_Chatbot / app.py
# Telegram bot that serves a mental-health support LLM
# (HuggingFace model "tanusrich/Mental_Health_Chatbot").
import os
import torch
import logging
from telegram import Update
from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, ContextTypes
import telegram.ext.filters as filters
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configure root logging once for the whole process; module logger below.
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the model and tokenizer at import time.
# NOTE(review): this blocks startup until the weights are downloaded/loaded
# into memory — confirm that is acceptable for the deployment target.
model_name = "tanusrich/Mental_Health_Chatbot"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer GPU when available; generation falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Function to read the token from the token.txt file
def read_token():
    """Read the Telegram bot token from ``my_support_bot_token.txt``.

    Returns:
        str: The token with leading/trailing whitespace stripped.

    Raises:
        FileNotFoundError: If the token file does not exist.
    """
    # Explicit encoding: the default is platform-dependent and could
    # mis-decode the file on non-UTF-8 locales.
    with open('my_support_bot_token.txt', 'r', encoding='utf-8') as file:
        return file.read().strip()
# Function to format the prompt for the model
def format_prompt(prompt, chat_history):
    """Build the Llama-style ``[INST]`` prompt from the latest user message.

    Each prior turn in *chat_history* (dicts with ``'user'`` and ``'ai'``
    keys) is rendered as a "User:/AI:" pair ahead of the new message.
    """
    turns = []
    for turn in chat_history:
        turns.append(f"User: {turn['user']}\nAI: {turn['ai']}\n")
    history = "".join(turns)
    return f"[INST] <<SYS>> You are a virtual AI therapy assistant. Your role is to provide thoughtful and supportive responses. Always ensure that you complete your last sentence with a period.<</SYS>> {history}User: {prompt.strip()} [/INST]"
# Function to clean the output from the model
def clean_output(output_text, input_text):
    """Strip the echoed prompt and instruction markers from a model response.

    Removes every occurrence of *input_text*, the ``[INST]``/``[/INST]``
    delimiters, and literal "(period)" artifacts, then trims whitespace.
    """
    cleaned = output_text.replace(input_text, "")
    for marker in ("[INST]", "[/INST]", "(period)", "(Period)"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle the /start command with a supportive greeting."""
    greeting = "Hello! I'm here to listen and support you. How can I help today?"
    await update.message.reply_text(greeting)
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Generate and send a model reply to an incoming text message.

    Maintains a rolling per-chat history in ``context.chat_data`` so that
    follow-up messages carry conversational context (the original always
    passed an empty history even though ``format_prompt`` supports one).
    """
    user_message = update.message.text
    # Per-chat list of {'user': ..., 'ai': ...} dicts, persisted by
    # python-telegram-bot between updates in the same chat.
    chat_history = context.chat_data.setdefault('history', [])
    formatted_prompt = format_prompt(user_message, chat_history)
    # Tokenize input and generate a response on the configured device.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # Lazy %-args: the message is only formatted if DEBUG logging is enabled.
    logger.debug("here is the llm response: %s", response)
    # Clean response, record the turn (capped at the last 5 exchanges to
    # bound prompt length), and send it back to the user.
    clean_response = clean_output(response, formatted_prompt)
    chat_history.append({'user': user_message, 'ai': clean_response})
    del chat_history[:-5]
    await update.message.reply_text(clean_response)
def main():
    """Wire up the Telegram handlers and run the bot with long polling."""
    application = ApplicationBuilder().token(read_token()).build()
    application.add_handler(CommandHandler("start", start))
    application.add_handler(
        MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)
    )
    application.run_polling()


if __name__ == '__main__':
    main()