Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import logging | |
| from telegram import Update | |
| from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, ContextTypes | |
| import telegram.ext.filters as filters | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Configure root logging at import time: timestamped records at INFO and above.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# Load the chat model and its tokenizer. model_name is passed to
# from_pretrained, so it is presumably a Hugging Face Hub repo id (or local path).
model_name = "tanusrich/Mental_Health_Chatbot"

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)
# Read the Telegram bot token from a local text file.
def read_token(path="my_support_bot_token.txt"):
    """Return the Telegram bot token stored in *path*.

    Args:
        path: Text file containing the token. Defaults to
            ``my_support_bot_token.txt`` in the current working directory
            (the file the original implementation hard-coded).

    Returns:
        The token with surrounding whitespace (e.g. a trailing newline) removed.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        ValueError: If the file is empty or contains only whitespace.
    """
    # "utf-8-sig" transparently drops a UTF-8 BOM (as written by some editors);
    # str.strip() would NOT remove it and the token would be rejected.
    with open(path, "r", encoding="utf-8-sig") as file:
        token = file.read().strip()
    if not token:
        # Fail here with a clear message instead of deep inside the bot client.
        raise ValueError(f"Telegram bot token file {path!r} is empty")
    return token
# Build the instruction-style prompt string that is fed to the model.
def format_prompt(prompt, chat_history):
    """Wrap the system instructions, prior turns and the new user message.

    Args:
        prompt: The latest user message; surrounding whitespace is stripped.
        chat_history: Iterable of mappings with ``'user'`` and ``'ai'`` keys,
            rendered oldest-first as ``User: ...`` / ``AI: ...`` lines.

    Returns:
        A single ``[INST] <<SYS>> ... <</SYS>> ... [/INST]`` prompt string.
    """
    system_text = (
        "You are a virtual AI therapy assistant. "
        "Your role is to provide thoughtful and supportive responses. "
        "Always ensure that you complete your last sentence with a period."
    )

    history_parts = []
    for turn in chat_history:
        history_parts.append(f"User: {turn['user']}\nAI: {turn['ai']}\n")
    history = "".join(history_parts)

    user_turn = f"User: {prompt.strip()}"
    return "[INST] <<SYS>> " + system_text + "<</SYS>> " + history + user_turn + " [/INST]"
# Strip prompt echo and leftover formatting markers from generated text.
def clean_output(output_text, input_text):
    """Return *output_text* with prompt text and known artifacts removed.

    Every occurrence of *input_text* is deleted first, then the markers
    ``[INST]``, ``[/INST]``, ``(period)`` and ``(Period)`` are removed in that
    order, and finally surrounding whitespace is stripped.
    """
    cleaned = output_text.replace(input_text, "")
    for artifact in ("[INST]", "[/INST]", "(period)", "(Period)"):
        cleaned = cleaned.replace(artifact, "")
    return cleaned.strip()
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle the /start command by replying with a short greeting."""
    greeting = "Hello! I'm here to listen and support you. How can I help today?"
    incoming = update.message
    await incoming.reply_text(greeting)
def _generate_reply(formatted_prompt):
    """Run blocking model generation for one prompt; return only the new text."""
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    # Entered here (inside the worker thread): PyTorch grad mode is thread-local.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            # temperature/top_k/top_p only apply when sampling is enabled;
            # without do_sample=True generation is greedy and they are ignored.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )
    # A causal LM's generate() output starts with the prompt tokens; slice them
    # off rather than string-replacing the decoded prompt, which breaks whenever
    # decode() does not reproduce the prompt text exactly.
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)


async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Generate a model reply to an incoming text message and send it back.

    Generation runs in the event loop's default executor so the (slow,
    synchronous) model call does not block the asyncio loop that drives the
    bot. Updates without ``update.message`` (e.g. edited messages, which also
    match ``filters.TEXT``) and messages without text are ignored.
    """
    import asyncio  # local import: only this handler needs it

    message = update.message
    if message is None or not message.text:
        return

    # Prepare prompt for model input (no conversation history is kept).
    formatted_prompt = format_prompt(message.text, [])

    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, _generate_reply, formatted_prompt)
    logger.debug("here is the llm response: %s", response)

    # Clean leftover markers; Telegram rejects empty message text, so fall back.
    clean_response = clean_output(response, formatted_prompt)
    if not clean_response:
        clean_response = "I'm sorry, I couldn't put a reply together. Could you tell me a bit more?"
    await message.reply_text(clean_response)
def main():
    """Build the Telegram application, register its handlers and start polling."""
    bot_token = read_token()
    app = ApplicationBuilder().token(bot_token).build()

    # /start gets the greeting; any other non-command text goes to the model.
    routes = (
        CommandHandler("start", start),
        MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message),
    )
    for route in routes:
        app.add_handler(route)

    # Hand control to python-telegram-bot's polling loop.
    app.run_polling()
if __name__ == "__main__":
    # Start the bot only when executed as a script, not when imported.
    main()