import asyncio
import logging
import os

import telegram.ext.filters as filters
import torch
from telegram import Update
from telegram.ext import ApplicationBuilder, CommandHandler, ContextTypes, MessageHandler
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load the model and tokenizer once at import time; all handlers share these objects.
model_name = "tanusrich/Mental_Health_Chatbot"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Marker strings removed from model output by clean_output, applied in this order.
_RESPONSE_ARTIFACTS = ("[INST]", "[/INST]", "(period)", "(Period)")

# Sent when the cleaned model output is empty: Telegram rejects empty message text.
_EMPTY_REPLY_FALLBACK = (
    "I'm sorry, I couldn't find the right words just now. "
    "Could you tell me a little more?"
)


def read_token(path="my_support_bot_token.txt"):
    """Return the Telegram bot token stored in the file at *path*.

    Leading/trailing whitespace (such as a trailing newline) is stripped.
    """
    with open(path, "r", encoding="utf-8") as file:
        return file.read().strip()


def format_prompt(prompt, chat_history):
    """Build the instruction prompt sent to the model.

    Args:
        prompt: the user's latest message; surrounding whitespace is stripped.
        chat_history: iterable of dicts with ``'user'`` and ``'ai'`` keys,
            rendered oldest-first as ``User: ...`` / ``AI: ...`` lines.

    Returns:
        The prompt string wrapped in ``[INST] ... [/INST]`` with the system
        instructions inside ``<<SYS>> ... <</SYS>>``.
    """
    # NOTE(review): the [INST]/<<SYS>> markers follow the Llama-2 chat template;
    # presumably the fine-tuned model was trained on it -- confirm.
    history = "".join(
        f"User: {entry['user']}\nAI: {entry['ai']}\n" for entry in chat_history
    )
    return (
        "[INST] <<SYS>> You are a virtual AI therapy assistant. Your role is to provide "
        "thoughtful and supportive responses. Always ensure that you complete your last "
        "sentence with a period.<</SYS>> "
        f"{history}User: {prompt.strip()} [/INST]"
    )


def clean_output(output_text, input_text):
    """Remove any echoed prompt and leftover instruction markers from model output.

    Args:
        output_text: decoded model output.
        input_text: the prompt that was fed to the model; every occurrence of it
            in *output_text* is removed.

    Returns:
        The cleaned text with surrounding whitespace stripped (may be empty).
    """
    output_text = output_text.replace(input_text, "")
    for artifact in _RESPONSE_ARTIFACTS:
        output_text = output_text.replace(artifact, "")
    return output_text.strip()


def _generate_reply(formatted_prompt):
    """Run the model on *formatted_prompt* and return only the newly generated text.

    This call blocks for the full duration of generation; callers on the event
    loop should run it in an executor.
    """
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    # Grad mode is thread-local, so no_grad must be entered in the thread that
    # actually runs generate().
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            # Without do_sample=True the sampling knobs below can be ignored,
            # depending on the model's generation config.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt tokens followed by the continuation. Decode only
    # the continuation rather than string-replacing the prompt, which breaks
    # whenever decoding does not reproduce the prompt text exactly.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)


async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle /start by sending a short greeting."""
    # effective_message also covers edited messages, where update.message is None.
    await update.effective_message.reply_text(
        "Hello! I'm here to listen and support you.\nHow can I help today?"
    )


async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Answer a plain-text (non-command) message with a model-generated reply."""
    # effective_message also covers edited messages and channel posts, for which
    # update.message is None.
    message = update.effective_message

    # Prepare prompt for model input (no conversation history is kept).
    formatted_prompt = format_prompt(message.text, [])

    # Generation is slow and synchronous; run it off the event loop so the
    # bot stays responsive to Telegram while the model works.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, _generate_reply, formatted_prompt)
    logger.debug("here is the llm response: %s", response)

    # Clean response and send it back; never send empty text, which Telegram rejects.
    clean_response = clean_output(response, formatted_prompt)
    await message.reply_text(clean_response or _EMPTY_REPLY_FALLBACK)


def main():
    """Read the bot token, register handlers, and poll Telegram until stopped."""
    token = read_token()

    # Create an Application instance
    application = ApplicationBuilder().token(token).build()

    # Register command and text-message handlers
    application.add_handler(CommandHandler("start", start))
    application.add_handler(
        MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)
    )

    # Start the bot (blocks until interrupted)
    application.run_polling()


if __name__ == "__main__":
    main()