# Gemma_Chatbot / app.py — Hugging Face Space entry point
# Author: AvocadoMuffin — commit f239ec4 (verified), "Update app.py"
import gradio as gr
#from gradio.components import Box
import torch
import os
import faiss
import pickle
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
import pandas as pd
from datasets import Dataset
from huggingface_hub import create_repo
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# Best-effort bitsandbytes import for efficient model loading: prefer the
# GPU build when CUDA is present. The app still runs without the package,
# so an ImportError is only reported, never raised.
try:
    _has_cuda = torch.cuda.is_available()
    if _has_cuda:
        import bitsandbytes as bnb
    else:
        import bitsandbytes  # the same package serves as the CPU fallback
    print("Using bitsandbytes with GPU support." if _has_cuda else "Using bitsandbytes with CPU support.")
except ImportError:
    print("bitsandbytes not found, falling back to CPU.")
# Fine-tuned Gemma checkpoint hosted on the Hugging Face Hub; the same repo
# also carries the FAISS index files downloaded below.
MODEL_NAME = "AvocadoMuffin/Gemma_Fine_Tuned_Model"
HF_TOKEN = os.getenv("HF_TOKEN")  # must be set in the environment for private-repo access
# Pick device/dtype from hardware: fp16 on GPU, fp32 on CPU (fp16 is slow/unsupported on CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype, token=HF_TOKEN).to(device)
# Gemma's tokenizer ships without a pad token; reuse EOS so padding works.
tokenizer.pad_token = tokenizer.eos_token
# Download the FAISS index and the pickled knowledge base from the model repo.
faiss_index_path = hf_hub_download(repo_id=MODEL_NAME, filename="faiss_index/index.faiss", token=HF_TOKEN)
knowledge_path = hf_hub_download(repo_id=MODEL_NAME, filename="faiss_index/knowledge.pkl", token=HF_TOKEN)
# Load FAISS index (vector side of the knowledge base).
index = faiss.read_index(faiss_index_path)
# Load knowledge.pkl (text side: questions + answers).
# NOTE(review): pickle.load executes arbitrary code from the downloaded file —
# acceptable only because the repo is project-owned; do not point this at
# third-party repos.
with open(knowledge_path, "rb") as f:
    knowledge_data = pickle.load(f)
# Fail fast if the pickle does not have the expected schema.
if isinstance(knowledge_data, dict) and "questions" in knowledge_data and "answers" in knowledge_data:
    questions = knowledge_data["questions"]
    answers = knowledge_data["answers"]
else:
    raise ValueError("โŒ knowledge.pkl does not contain expected questions and answers keys!")
# Sentence-transformer used both for query embedding and for near-duplicate
# filtering of retrieved answers. Must match the model the index was built with.
embedding_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
print("โœ… FAISS Index, Knowledge Data, and Embedding Model Loaded Successfully!")
# ๐Ÿ”น Function: Retrieve Relevant Context from FAISS
def retrieve_relevant_context(query, index, answers, embedding_model, top_k=3, similarity_threshold=0.8):
    """Fetch up to `top_k` knowledge snippets for `query`, dropping near-duplicates.

    The query is embedded and searched against the FAISS index; each candidate
    answer is kept only if its cosine similarity to every already-kept answer
    stays below `similarity_threshold`. Returns the kept snippets joined by
    newlines (empty string when nothing qualifies).
    """
    q_vec = embedding_model.encode([query], convert_to_numpy=True)
    _, hit_ids = index.search(q_vec, top_k)

    snippets = []
    kept_vecs = []
    for hit in hit_ids[0]:
        # FAISS pads missing results with -1; also guard against stale ids.
        if not (0 <= hit < len(answers)):
            continue
        candidate = answers[hit]
        cand_vec = embedding_model.encode([candidate], convert_to_numpy=True)
        # Keep only if sufficiently dissimilar from every snippet already kept
        # (vacuously true for the first candidate).
        is_novel = all(util.cos_sim(cand_vec, prev) < similarity_threshold for prev in kept_vecs)
        if is_novel:
            snippets.append(candidate)
            kept_vecs.append(cand_vec)
    return "\n".join(snippets[:top_k])
# Persona instructions for the two supported assistant modes; the selected
# value is injected verbatim into the generation prompt as SYSTEM MESSAGE
# and also used as a sentinel to pick the prompt template.
SYSTEM_MESSAGES = {
    "general": "You are a helpful AI assistant for everyday tasks.",
    "mental_health": "You are a helpful AI assistant for mental health."
}
# Keywords used to decide whether a query belongs to mental-health mode.
# Matching is a case-insensitive *substring* test, so short entries such as
# "sad" can fire inside unrelated words (e.g. "saddle") — keep entries as
# specific as practical. (Fixed: "trauma" was listed twice.)
MENTAL_HEALTH_KEYWORDS = [
    "anxiety", "depression", "stress", "therapy", "counseling", "mental health", "panic", "trauma",
    "mindfulness", "meditation", "psychological", "self-care", "emotion", "feeling", "mood",
    "psychiatry", "therapist", "psychologist", "burnout", "coping", "ptsd", "ocd", "adhd", "bipolar",
    "schizophrenia", "grief", "insomnia", "wellbeing", "mental wellbeing", "mental wellness",
    "emotional wellbeing", "emotional wellness", "self-help", "relaxation", "breathing exercise",
    "cognitive behavioral", "cbt", "sad", "angry", "worried", "anxious", "depressed", "overwhelmed",
    "psychology", "mental disorder", "suicide", "suicidal", "crisis", "loneliness", "lonely", "isolated",
    "phobia", "addiction", "substance abuse", "eating disorder", "self-esteem", "confidence"
]


def is_mental_health_query(query):
    """Return True when `query` contains any mental-health keyword (case-insensitive)."""
    query_lower = query.lower()
    return any(keyword in query_lower for keyword in MENTAL_HEALTH_KEYWORDS)
# ๐Ÿ”น Function: Generate Response Using RAG
def generate_rag_response(query, model, tokenizer, index, answers, embedding_model, chat_history, system_message, max_new_tokens=200):
    """Generate a reply for `query` using RAG: FAISS-retrieved context + recent chat history.

    Args:
        query: Current user message.
        model / tokenizer: Causal LM and its tokenizer used for generation.
        index / answers / embedding_model: Passed through to retrieve_relevant_context().
        chat_history: List of {"role", "content"} dicts; the last 5 entries feed the prompt.
        system_message: One of the SYSTEM_MESSAGES values; selects the prompt template.
        max_new_tokens: Generation budget for model.generate().

    Returns:
        The decoded model continuation (prompt prefix stripped), or a canned
        refusal when an off-topic query arrives in mental-health mode.
    """
    # Mental-health mode refuses queries the keyword check deems off-topic.
    if system_message == SYSTEM_MESSAGES["mental_health"] and not is_mental_health_query(query):
        return "I'm currently in mental health assistant mode and can only answer questions related to mental health, psychology, emotions, and wellbeing. If you'd like to discuss other topics, please change the system message to general assistant mode."
    query_lower = query.strip().lower()
    # Phrases that signal the user wants the previous answer continued.
    follow_up_triggers = ["continue", "go on", "tell me more", "what else?", "elaborate", "explain more"]
    is_follow_up = any(trigger in query_lower for trigger in follow_up_triggers)
    # Retrieve relevant knowledge snippets from the FAISS index.
    retrieved_context = retrieve_relevant_context(query, index, answers, embedding_model).strip()
    # Flatten the last 5 chat entries into plain text for the prompt.
    chat_context = "\n".join([msg["content"] for msg in chat_history[-5:]])
    # For follow-ups, prepend the previous message so the model continues smoothly.
    # NOTE(review): with the current caller (chatbot() appends the user turn
    # before calling this), chat_history[-1] is the just-appended *user*
    # message, not the last AI reply — confirm whether [-2] was intended.
    if is_follow_up and chat_history:
        last_ai_response = chat_history[-1]  # Get last AI message
        query = f"{last_ai_response['content']} {query}"  # Append the follow-up to continue smoothly
    # Few-shot examples used in mental-health mode only (and only once a
    # conversation exists — see the prompt construction below).
    few_shot_examples = """
User: What are some effective meditation techniques for beginners?
AI: Based on established research, here are some effective meditation techniques for beginners:
1. Mindful Breathing - Focus on your breath, noticing the sensation of air moving in and out of your body.
2. Body Scan Meditation - Progressively relax different parts of your body from head to toe.
3. Loving-Kindness Meditation - Direct positive wishes toward yourself and others.
4. Guided Meditation - Follow along with recorded instructions to help maintain focus.
5. Walking Meditation - Practice mindfulness while walking slowly and deliberately.
I recommend starting with just 5 minutes daily and gradually increasing duration. Would you like more details about any specific technique, or would you prefer suggestions for apps that could help you get started?
User: Tell me more about body scan meditation.
AI: Body scan meditation is a mindfulness practice where you systematically focus your attention on different parts of your body, from your toes to the top of your head.
Here's how to practice it:
1. Lie down in a comfortable position or sit in a chair with your feet on the ground
2. Close your eyes and take several deep breaths
3. Begin by bringing awareness to your toes, noticing any sensations (tingling, temperature, tension)
4. Gradually move your attention upward (feet, ankles, calves, etc.)
5. For each body part, notice sensations without judgment, then consciously relax that area
6. If you notice pain or discomfort, acknowledge it with compassion before moving on
7. Continue until you've scanned your entire body
Regular practice can help reduce stress, improve body awareness, and release physical tension. Many find it helpful for sleep issues as well.
Would you like guidance on how long to practice or how to handle common challenges like mind-wandering during body scans?
"""
    # Pick the prompt template for the active mode.
    if system_message == SYSTEM_MESSAGES["mental_health"]:
        prompt = f"""You are a compassionate mental wellness assistant powered by a retrieval-augmented generation (RAG) system.
Maintain a warm, supportive conversation while providing evidence-based mental health information.
SYSTEM MESSAGE: {system_message}
When responding to mental health queries:
- Provide evidence-based information from reputable sources
- Keep the tone supportive, non-judgmental, and empathetic
- If uncertain about mental health information, acknowledge limitations and avoid potentially harmful advice
- For crisis situations, recommend professional help rather than providing advice
- Frame mental wellness information in a way that emphasizes self-compassion and gradual progress
If the user expresses overwhelm, anxiety, or difficulty with their mental health, offer evidence-based coping strategies and gentle support.
{few_shot_examples if len(chat_history) > 0 else ""}
"""
    else:
        prompt = f"""You are a helpful AI assistant designed to answer questions on a wide range of topics.
Provide accurate, concise information while maintaining a conversational tone.
SYSTEM MESSAGE: {system_message}
When responding to queries:
- Provide accurate information on the requested topic
- Keep explanations clear and easy to understand
- Be conversational and friendly
- If you don't know something, acknowledge limitations rather than making up information
"""
    # Only append the retrieved context section when retrieval found something.
    if retrieved_context:
        prompt += f"\n๐Ÿ“Œ Relevant Information Retrieved:\n{retrieved_context}\n\n"
    prompt += f"๐Ÿ’ฌ Chat History:\n{chat_context}\n๐Ÿ‘ค User: {query}\n๐Ÿค– AI:"
    # Tokenize the assembled prompt (module-level `device` chosen at startup).
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Generate without gradient tracking (inference only).
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    # Causal LMs echo the prompt; strip it so only the continuation is returned.
    if response.startswith(prompt):
        response = response[len(prompt):].strip()
    return response
# ๐Ÿ”น Function to save feedback for RLHF
from huggingface_hub import create_repo, HfApi
from datasets import Dataset, load_dataset
import pandas as pd
import os
import json
def save_feedback(query, response, feedback, comment=""):
    """Persist one feedback record locally and (best-effort) to the HF Hub dataset.

    Merges the new record with any existing data (local CSV first, Hub copy as
    fallback), writes the combined CSV locally, then pushes it to the Hub.

    Args:
        query: The user query that produced the response.
        response: The assistant response being rated.
        feedback: 1 (like), 0 (dislike), or None when no rating was given.
        comment: Optional free-text comment.

    Returns:
        A fixed status string; errors are printed, never raised.
    """
    feedback_file = "chat_feedback.csv"
    repo_name = "AvocadoMuffin/Chat_Feedback_Data"
    # One-row frame for the current feedback record.
    new_data = pd.DataFrame([[query, response, feedback, comment]],
                            columns=["Query", "Response_Generated", "Feedback", "Comment"])
    # Tracks whether the combined data already reached the local CSV, so the
    # error path below does not append the same record a second time.
    saved_locally = False
    try:
        # Prefer the local copy when we have one.
        if os.path.exists(feedback_file):
            existing_data = pd.read_csv(feedback_file)
            print(f"Loaded existing data from local file with {len(existing_data)} records")
        else:
            # Otherwise try to pull the dataset from the HF Hub.
            api = HfApi()
            try:
                # List repo files to locate the data file.
                files = api.list_repo_files(repo_id=repo_name, repo_type="dataset")
                data_files = [f for f in files if f.endswith('.parquet') or f.endswith('.csv')]
                if data_files:
                    # Download the first data file we find.
                    file_path = api.hf_hub_download(
                        repo_id=repo_name,
                        filename=data_files[0],
                        repo_type="dataset"
                    )
                    # Load according to file extension.
                    if file_path.endswith('.parquet'):
                        existing_data = pd.read_parquet(file_path)
                    else:
                        existing_data = pd.read_csv(file_path)
                    print(f"Successfully loaded data from Hub with {len(existing_data)} records")
                else:
                    print("No data files found in repository")
                    existing_data = pd.DataFrame(columns=["Query", "Response_Generated", "Feedback", "Comment"])
            except Exception as e:
                # Hub unavailable or repo empty: start a fresh frame.
                print(f"Error accessing repository: {e}")
                existing_data = pd.DataFrame(columns=["Query", "Response_Generated", "Feedback", "Comment"])
        # Append new feedback to the existing data.
        updated_data = pd.concat([existing_data, new_data], ignore_index=True)
        print(f"Updated dataset now has {len(updated_data)} records")
        # Always save locally for backup and future reference.
        updated_data.to_csv(feedback_file, index=False)
        saved_locally = True
        print(f"Saved updated data to local file with {len(updated_data)} records")
        # Push the combined data to the Hub (creating the repo if needed).
        dataset = Dataset.from_pandas(updated_data)
        create_repo(repo_id=repo_name, repo_type="dataset", exist_ok=True)
        dataset.push_to_hub(repo_name)
        print("Feedback saved and dataset updated on Hugging Face!")
    except Exception as e:
        print(f"Error updating dataset: {e}")
        # BUGFIX: only write the backup when the record never reached the local
        # CSV. Previously a failed Hub push *after* a successful local save
        # appended the same record a second time, duplicating it.
        if not saved_locally:
            if os.path.exists(feedback_file):
                local_data = pd.read_csv(feedback_file)
                pd.concat([local_data, new_data], ignore_index=True).to_csv(feedback_file, index=False)
            else:
                new_data.to_csv(feedback_file, index=False)
            print("Saved locally only due to error")
    return "Feedback saved and dataset updated!"
# ๐Ÿ”น The chatbot function (with RAG and conversation history tracking)
def chatbot(query, chat_history=None, system_message=SYSTEM_MESSAGES["general"]):
    """Handle one chat turn for the Gradio UI.

    Appends the user message and the generated reply to `chat_history`,
    renders the whole conversation as styled HTML, and returns the component
    updates expected by the submit handlers in create_gradio_interface().
    """
    if chat_history is None:
        chat_history = []  # Fresh conversation
    # Record the user turn before generating (generate_rag_response reads history).
    chat_history.append({"role": "user", "content": query})
    # Generate the reply with the module-level model/index and the active mode.
    response = generate_rag_response(query, model, tokenizer, index, answers, embedding_model, chat_history, system_message)
    # Record the assistant turn so future calls see full context.
    chat_history.append({"role": "assistant", "content": response})
    # Render the conversation as inline-styled HTML bubbles (blue = user, green = AI).
    formatted_chat_history = ""
    for msg in chat_history:
        if msg["role"] == "user":
            formatted_chat_history += f'<p style="color: #4A90E2; font-weight: bold;">User:</p><p style="background-color: #E7F1F8; padding: 10px; border-radius: 5px; font-family: Arial, sans-serif; font-size: 14px;">{msg["content"]}</p>'
        else:
            formatted_chat_history += f'<p style="color: #1D4ED8; font-weight: bold;">AI:</p><p style="background-color: #D1F4D9; padding: 10px; border-radius: 5px; font-family: Arial, sans-serif; font-size: 14px;">{msg["content"]}</p>'
    # Return order must match the `outputs=` list in the event wiring:
    # chat HTML, history state, feedback radio (shown + cleared), feedback
    # button (shown), comment box (shown + cleared), latest response, cleared
    # input textbox, latest query.
    return gr.HTML(formatted_chat_history), chat_history, gr.update(visible=True, value=None), gr.update(visible=True), gr.update(visible=True, value=""), response, gr.update(value=""), query
# Function to handle feedback submission
def handle_feedback(query, latest_response, feedback, comment):
    """Persist the user's optional rating for the last exchange, then hide the feedback widgets."""
    # Map the radio selection to a numeric label; no selection is stored as None.
    rating = None
    if feedback:
        rating = 1 if feedback == "๐Ÿ‘ Like" else 0
    save_feedback(query, latest_response, rating, comment)
    # Collapse the feedback controls and clear their values for the next turn.
    return gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=False, value="")
# Function to handle system message change and reset chat
def change_system_message(message_key, chat_history_state):
    """Switch assistant mode: discard the conversation and greet in the new mode.

    Returns the rendered greeting HTML, the fresh history list, and the
    system-message string for the selected mode.
    """
    selected_message = SYSTEM_MESSAGES[message_key]
    # Mode-specific greeting for the freshly reset conversation.
    if message_key == "mental_health":
        welcome_message = "I'm now in mental health assistant mode. I can help with questions about mental health, wellness, emotions, and psychological well-being. How can I support you today?"
    else:
        welcome_message = "I'm here to help with a wide range of questions and tasks. What would you like assistance with today?"
    new_chat_history = [{"role": "assistant", "content": welcome_message}]
    # Render with the same bubble styling the chat view uses for AI messages.
    formatted_chat_history = f'<p style="color: #1D4ED8; font-weight: bold;">AI:</p><p style="background-color: #D1F4D9; padding: 10px; border-radius: 5px; font-family: Arial, sans-serif; font-size: 14px;">{welcome_message}</p>'
    return gr.HTML(formatted_chat_history), new_chat_history, selected_message
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the chatbot.

    Layout: a chat column (history, input, optional feedback widgets) beside a
    system-message column (example table, mode selector, current-instruction
    display, reset button). All event wiring (submit/click/change/load) is
    done at the end of this function.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## **๐Ÿง  Mindful Bot: AI for Everything**")
        gr.Markdown("An AI chatbot designed to offer empathetic support for mental wellness using RAG-based techniques, capable of handling adversarial queries with care. It also provides insightful responses to general queries, separating wellness-related conversations for more focused, compassionate assistance.")
        # Hidden per-session state carried between event handlers.
        latest_query = gr.State("")
        latest_response = gr.State("")
        chat_history_state = gr.State([])
        current_system_message = gr.State(SYSTEM_MESSAGES["general"])
        with gr.Row():
            with gr.Column(scale=7):  # Chat area
                # Rendered conversation (HTML bubbles produced by chatbot()).
                chat_history_box = gr.HTML(label="Chat History", elem_id="chat-history")
                with gr.Row():
                    with gr.Column(scale=8):
                        query = gr.Textbox(label="Enter your query", elem_id="user-input", lines=3)
                    with gr.Column(scale=2):
                        submit_button = gr.Button("Submit", elem_id="submit-button", variant="primary")
                # Feedback UI — hidden until a response exists; no default selection.
                feedback = gr.Radio(choices=["๐Ÿ‘ Like", "๐Ÿ‘Ž Dislike"], label="Feedback (Optional)", type="value", visible=False)
                submit_feedback_button = gr.Button("Submit Feedback", visible=False)
                feedback_comments = gr.Textbox(label="Additional Comments (optional)", lines=2, visible=False)
            with gr.Column(scale=3):  # System message area
                with gr.Group():
                    gr.HTML("<h2 style='text-align: center; color: white; background-color: orange; padding: 10px;'>System message</h2>")
                    # Example table illustrating which mode fits which message.
                    example_table = gr.Dataframe(
                        headers=["Message", "System message"],
                        datatype=["str", "str"],
                        value=[
                            ["Hello! How are you?", SYSTEM_MESSAGES["mental_health"]],
                            ["Can you help with a recipe for baking a cake?", SYSTEM_MESSAGES["general"]],
                        ],
                        row_count=2,
                        col_count=2,
                        interactive=True,
                        #height=150
                    )
                    # Mode selector: keys into SYSTEM_MESSAGES.
                    system_selector = gr.Radio(
                        choices=["general", "mental_health"],
                        value="general",
                        label="Select system message type",
                        elem_id="system-selector",
                        info="Choose the type of assistant"
                    )
                    # Read-only display of the active system instruction.
                    system_message_display = gr.Textbox(
                        value=SYSTEM_MESSAGES["general"],
                        label="Current system instruction:",
                        interactive=False
                    )
                    # Reset button (re-runs the mode-change handler).
                    reset_button = gr.Button("Reset Chat", elem_id="reset-button")
        # Custom CSS injected via an HTML component.
        gr.HTML("""<style>
#chat-history { background-color: #f5f5f5; border-radius: 10px; padding: 10px; max-height: 400px; overflow-y: scroll; font-family: 'Arial', sans-serif; }
#user-input { margin-top: 10px; border-radius: 10px; }
#submit-button { margin-top: 10px; }
#system-selector { margin-top: 10px; }
</style>""")
        # Event wiring. Enter in the textbox and the Submit button run the same
        # handler; the outputs list must match chatbot()'s 8-tuple return order.
        query.submit(chatbot,
                     inputs=[query, chat_history_state, current_system_message],
                     outputs=[chat_history_box, chat_history_state, feedback, submit_feedback_button, feedback_comments, latest_response, query, latest_query])
        submit_button.click(chatbot,
                            inputs=[query, chat_history_state, current_system_message],
                            outputs=[chat_history_box, chat_history_state, feedback, submit_feedback_button, feedback_comments, latest_response, query, latest_query])
        # Feedback submission: persists the rating, then hides the widgets.
        submit_feedback_button.click(handle_feedback,
                                     inputs=[latest_query, latest_response, feedback, feedback_comments],
                                     outputs=[feedback, submit_feedback_button, feedback_comments])
        # Mode change resets the chat and swaps the system message...
        system_selector.change(change_system_message,
                               inputs=[system_selector, chat_history_state],
                               outputs=[chat_history_box, chat_history_state, current_system_message])
        # ...and mirrors the new instruction into the read-only display.
        system_selector.change(lambda x: SYSTEM_MESSAGES[x],
                               inputs=[system_selector],
                               outputs=[system_message_display])
        # Reset button reuses the same two handlers as a mode change.
        reset_button.click(change_system_message,
                           inputs=[system_selector, chat_history_state],
                           outputs=[chat_history_box, chat_history_state, current_system_message])
        reset_button.click(lambda x: SYSTEM_MESSAGES[x],
                           inputs=[system_selector],
                           outputs=[system_message_display])
        # On page load: show the welcome message for the default mode...
        demo.load(change_system_message,
                  inputs=[system_selector, chat_history_state],
                  outputs=[chat_history_box, chat_history_state, current_system_message])
        # ...and initialize the system-message display.
        demo.load(lambda x: SYSTEM_MESSAGES[x],
                  inputs=[system_selector],
                  outputs=[system_message_display])
    return demo
# Build and launch the UI at import time — this file is meant to be executed
# directly as the Space's entry point (no __main__ guard by design).
demo = create_gradio_interface()
demo.launch(share=False, debug=True)