import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"  # Automatically detect GPU or CPU

model_name = "tanusrich/Mental_Health_Chatbot"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # half precision only on GPU; CPU kernels prefer float32
    device_map="auto",  # place layers on the GPU when available, otherwise on CPU
    low_cpu_mem_usage=True,
    max_memory={0: "3.5GiB", "cpu": "12GiB"} if device == "cuda" else {"cpu": "12GiB"},  # cap GPU 0 at 3.5 GiB, spill the rest to CPU RAM
    offload_folder=None  # no disk offload
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
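# Many causal-LM tokenizers ship without a pad token, and generate() below
# pads with the EOS id. A minimal safeguard, assuming this checkpoint may
# lack one (a no-op if a pad token is already defined):
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token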
'''
model_save_path = "./model"
# Save model
model.save_pretrained(model_save_path)
# Save tokenizer
tokenizer.save_pretrained(model_save_path)'''
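# If the block above has been run once, the checkpoint can be reloaded from
# the local folder instead of re-downloading. A minimal sketch, assuming
# ./model contains the files written by save_pretrained:
'''
model = AutoModelForCausalLM.from_pretrained("./model", low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained("./model")
'''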
def generate_response(user_input, history=None):
    # gr.ChatInterface calls fn(message, history); the history argument is
    # accepted here, but the prompt below uses only the latest message.
    inputs = tokenizer(user_input, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,  # required for temperature/top_k/top_p to take effect
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # Extract only the chatbot's latest response; gr.ChatInterface keeps the
    # conversation history itself, so no manual history tracking is needed.
    chatbot_response = response.split("Chatbot:")[-1].strip()
    return chatbot_response
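# If past turns should influence the reply, the history that gr.ChatInterface
# passes in could be folded into the prompt before tokenizing. A minimal
# sketch, assuming history arrives as the default list of (user, bot) pairs;
# generate_response would then tokenize build_prompt(user_input, history)
# instead of the bare message:
'''
def build_prompt(user_input, history):
    turns = ""
    for user_msg, bot_msg in history or []:
        turns += f"User: {user_msg}\nChatbot: {bot_msg}\n"
    return turns + f"User: {user_input}\nChatbot:"
'''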
# Continuous conversation loop (for running in a terminal instead of Gradio)
'''while True:
    user_input = input("You: ")  # Take user input
    if user_input.lower() in ["exit", "quit", "stop"]:
        print("Chatbot: Goodbye!")
        break
    response = generate_response(user_input)
    print("Chatbot:", response)'''
# Initialize the ChatInterface
chatbot = gr.ChatInterface(fn=generate_response, title="Mental Health Chatbot")
chatbot.launch()
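# Under concurrent load, requests can be serialized through Gradio's queue;
# a possible variant, assuming the default queue settings are acceptable:
# chatbot.queue().launch()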
'''
# Example
user_input = "I'm feeling suicidal."
response = generate_response(user_input)
print("Chatbot: ", response)
'''