import hashlib
import hmac
import json
import os
import random
import re
import secrets
import uuid
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from pymongo import MongoClient
from transformers import AutoModelForCausalLM, AutoTokenizer

# ------------------------------
# ⚙️ Global Config
# ------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "bhushanrocks/supportpal-dialoGPT-v3"

# Prompt / generation limits.
MAX_CONTEXT_TURNS = 5     # most recent (user, bot) turns fed back to the model
MAX_PROMPT_TOKENS = 512   # token budget for the assembled prompt
MAX_NEW_TOKENS = 200      # cap on generated tokens per reply

# ------------------------------
# 🔐 MongoDB Setup & Local Fallback
# ------------------------------
MONGO_URI = os.getenv("MONGO_URI")
use_mongo = False
db = None
users_collection = None
chats_collection = None

if MONGO_URI:
    try:
        client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        client.server_info()  # forces a round-trip so a bad URI fails here, not on first use
        db = client["supportpal_db"]
        users_collection = db["users"]
        chats_collection = db["chats"]
        use_mongo = True
        print("✅ Connected to MongoDB for persistence and authentication.")
    except Exception as e:
        print(f"⚠️ MongoDB unavailable: {e}. Using local JSON storage.")
else:
    print("⚠️ MONGO_URI not set. Using local JSON storage.")

LOCAL_USERS_FILE = "local_users.json"
LOCAL_CHATS_FILE = "local_chats.json"


def load_local_data(filepath):
    """Load a JSON object from *filepath*; return {} if the file does not exist."""
    if os.path.exists(filepath):
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


def save_local_data(data, filepath):
    """Write *data* to *filepath* as indented JSON, atomically.

    The JSON is written to a sibling temp file and swapped in with
    ``os.replace``. A crash mid-write therefore cannot leave a truncated file
    that would make ``load_local_data`` fail on the next start.
    """
    tmp_path = f"{filepath}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, filepath)


local_users = load_local_data(LOCAL_USERS_FILE)
local_chats = load_local_data(LOCAL_CHATS_FILE)

# ------------------------------
# 🔑 Password Hashing
# ------------------------------
_HASH_SCHEME = "pbkdf2_sha256"
_PBKDF2_ITERATIONS = 200_000


def hash_password(password):
    """Return a salted PBKDF2-HMAC-SHA256 hash of *password*.

    Format: ``pbkdf2_sha256$<iterations>$<salt_hex>$<digest_hex>``. A random
    per-user salt and a slow KDF replace the previous unsalted SHA-256.
    """
    salt = secrets.token_hex(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), bytes.fromhex(salt), _PBKDF2_ITERATIONS
    ).hex()
    return f"{_HASH_SCHEME}${_PBKDF2_ITERATIONS}${salt}${digest}"


def verify_password(password, stored):
    """Return True if *password* matches the *stored* hash.

    Accepts the PBKDF2 format produced by ``hash_password``. It also accepts
    the legacy unsalted SHA-256 hex digests written by earlier versions, so
    existing accounts keep working. Comparisons use
    ``hmac.compare_digest`` to avoid timing side channels.
    """
    if not stored:
        return False
    if stored.startswith(_HASH_SCHEME + "$"):
        try:
            _, iterations, salt, expected = stored.split("$", 3)
            digest = hashlib.pbkdf2_hmac(
                "sha256", password.encode("utf-8"), bytes.fromhex(salt), int(iterations)
            ).hex()
        except ValueError:
            return False  # malformed stored hash
        return hmac.compare_digest(digest.encode(), expected.encode())
    legacy = hashlib.sha256(password.encode()).hexdigest()
    return hmac.compare_digest(legacy.encode(), stored.encode())


# ------------------------------
# 🔑 Auth Functions
# ------------------------------
def authenticate_user(username, password):
    """Check credentials.

    Returns (user_id, status message) on success and (None, error message)
    otherwise.
    """
    if not username or not password:
        return None, "Invalid username or password."
    if use_mongo:
        user_data = users_collection.find_one({"username": username})
    else:
        user_data = local_users.get(username)
    if user_data and verify_password(password, user_data.get("password", "")):
        return user_data["user_id"], f"Welcome back, {username}!"
    return None, "Invalid username or password."


def register_user(username, password):
    """Create an account.

    Returns (new user_id, status message) on success and (None, error message)
    when the input is invalid or the username is already taken.
    """
    if not (username and username.strip() and password):
        return None, "Username and password are required."
    hashed = hash_password(password)
    new_user_id = str(uuid.uuid4())
    if use_mongo:
        if users_collection.find_one({"username": username}):
            return None, "Username already taken."
        users_collection.insert_one({"username": username, "password": hashed, "user_id": new_user_id})
    else:
        if username in local_users:
            return None, "Username already taken."
        local_users[username] = {"username": username, "password": hashed, "user_id": new_user_id}
        save_local_data(local_users, LOCAL_USERS_FILE)
    return new_user_id, f"Account created! Welcome, {username}."


# ------------------------------
# 💾 Chat History Handling
# ------------------------------
def get_history(user_id):
    """Return the stored turn list for *user_id* ([] if none or no user)."""
    if not user_id:
        return []
    if use_mongo:
        convo = chats_collection.find_one({"user_id": user_id})
        # find_one returns None for users who have never saved a chat.
        return convo.get("history", []) if convo else []
    return local_chats.get(user_id, [])


def save_history(user_id, history):
    """Persist *history* for *user_id* to Mongo (upsert) or the local JSON file."""
    if not user_id:
        return
    if use_mongo:
        chats_collection.update_one({"user_id": user_id}, {"$set": {"history": history}}, upsert=True)
    else:
        local_chats[user_id] = history
        save_local_data(local_chats, LOCAL_CHATS_FILE)


def _turn_pair(turn: Any) -> Optional[Tuple[str, str]]:
    """Normalize a stored turn into a (user, bot) pair.

    Accepts a 2-item list/tuple or a ``{"user", "bot"}`` dict. Returns None
    for anything else.
    """
    if isinstance(turn, dict):
        return turn.get("user", ""), turn.get("bot", "")
    if isinstance(turn, (list, tuple)) and len(turn) == 2:
        return turn[0], turn[1]
    return None


def history_to_gradio_format(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Convert stored turns to Gradio ``type="messages"`` dicts, skipping empty sides."""
    gradio_msgs = []
    for turn in history:
        pair = _turn_pair(turn)
        if pair is None:
            continue
        user_msg, bot_msg = pair
        if user_msg:
            gradio_msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            gradio_msgs.append({"role": "assistant", "content": bot_msg})
    return gradio_msgs


# ------------------------------
# 🤗 Load Model
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# If a prompt still exceeds the budget, drop the oldest text rather than the
# user's latest message and the trailing "AI:" cue.
tokenizer.truncation_side = "left"
try:
    config = PeftConfig.from_pretrained(MODEL_NAME)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float16 if DEVICE == "cuda" else None,
    ).to(DEVICE)
    model = PeftModel.from_pretrained(base_model, MODEL_NAME)
    print(f"✅ Loaded PEFT model from {MODEL_NAME}")
except Exception as e:
    print(f"⚠️ Failed to load PEFT adapters, fallback: {e}")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

# ------------------------------
# 💬 Chat Logic
# ------------------------------
SYSTEM_PROMPT = (
    "You are SupportPal — a kind, compassionate AI that listens patiently, "
    "validates emotions, and gently helps people feel understood. "
    "Avoid talking about yourself, keep responses concise, warm, and natural."
)
_SYSTEM_ECHO_RE = re.compile(r"You are SupportPal.*", flags=re.DOTALL)
_GLAD_RE = re.compile(r"\b(I am|I'm) glad\b.*?\.")
_HOPE_RE = re.compile(r"\b(I hope|I wish)\b.*?\.")
_FALLBACK_RESPONSE = "I'm here with you. Tell me more about what’s on your mind. 💛"


def _build_prompt(history, user_input: str) -> str:
    """Assemble system prompt + recent turns + the new message within the token budget.

    Oldest turns are dropped one at a time until the prompt fits in
    MAX_PROMPT_TOKENS. The system prompt and the current message are
    therefore kept whenever possible.
    """
    pairs = [p for p in map(_turn_pair, history[-MAX_CONTEXT_TURNS:]) if p is not None]
    tail = f"\nHuman: {user_input}\nAI:"
    while True:
        turns_text = "\n".join(f"Human: {u}\nAI: {b}" for u, b in pairs)
        dialogue = SYSTEM_PROMPT + "\n\n" + turns_text + tail
        if not pairs or len(tokenizer(dialogue)["input_ids"]) <= MAX_PROMPT_TOKENS:
            return dialogue
        pairs.pop(0)


def _clean_response(generated: str) -> str:
    """Post-process raw generated text into the final reply."""
    # The model may continue the transcript with a fabricated next turn; keep only its reply.
    response = generated.split("Human:")[0].strip()
    response = _SYSTEM_ECHO_RE.sub("", response).strip()
    response = _GLAD_RE.sub("", response)
    response = _HOPE_RE.sub("", response)
    response = response.replace("AI:", "").strip()
    if not response:
        response = _FALLBACK_RESPONSE
    if random.random() < 0.25 and not any(x in response for x in ["💛", "✨"]):
        response += " " + random.choice(["💛", "✨", "You’re not alone in this."])
    return response


def generate_response(user_input: str, state_dict: dict):
    """Generate a reply, append the turn to history, and persist it.

    Returns (gradio-format history, updated state, status text). Returns
    ([], state, login prompt) when the session is not logged in.
    """
    if not state_dict.get("logged_in"):
        return [], state_dict, "Please log in to start chatting."

    user_id = state_dict["user_id"]
    history = state_dict.get("history", [])

    dialogue = _build_prompt(history, user_input)
    inputs = tokenizer(
        dialogue, return_tensors="pt", truncation=True, max_length=MAX_PROMPT_TOKENS
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; decoding the whole sequence would
    # re-include the prompt and require fragile string surgery to strip it.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response = _clean_response(generated)

    history.append((user_input, response))
    save_history(user_id, history)
    state_dict["history"] = history
    return history_to_gradio_format(history), state_dict, ""


def process_chat(user_message, state_dict):
    """Gradio callback wrapper.

    Ensures a session state exists, requires login, and ignores blank
    submissions (returning the current transcript unchanged).
    """
    if not state_dict:
        state_dict = {"logged_in": False, "user_id": None, "username": None, "history": []}
    if not state_dict.get("logged_in"):
        return [], state_dict, "Please log in to start chatting."
    if not user_message or not user_message.strip():
        return history_to_gradio_format(state_dict.get("history", [])), state_dict, ""
    return generate_response(user_message, state_dict)


# ------------------------------
# 🔐 Login / Register / Logout
# ------------------------------
def handle_login(username, password, state):
    """Log in.

    Returns (status, auth panel update, chat panel update, chat history,
    state, user status) for the Gradio outputs.
    """
    user_id, status = authenticate_user(username, password)
    if user_id:
        history = get_history(user_id)
        state.update({"logged_in": True, "user_id": user_id, "username": username, "history": history})
        gr_history = history_to_gradio_format(history)
        return status, gr.update(visible=False), gr.update(visible=True), gr_history, state, f"Logged in as: **{username}**"
    return status, gr.update(visible=True), gr.update(visible=False), [], state, gr.update()


def handle_register(username, password, state):
    """Register and log in.

    Returns the same output tuple shape as ``handle_login``.
    """
    user_id, status = register_user(username, password)
    if user_id:
        state.update({"logged_in": True, "user_id": user_id, "username": username, "history": []})
        return status, gr.update(visible=False), gr.update(visible=True), [], state, f"Logged in as: **{username}**"
    return status, gr.update(visible=True), gr.update(visible=False), [], state, gr.update()


def handle_logout():
    """Reset the UI and session state to the logged-out view."""
    return (
        "Logged out successfully.",
        gr.update(visible=True),
        gr.update(visible=False),
        [],
        {"logged_in": False, "user_id": None, "username": None, "history": []},
        gr.update(value=""),
        "Logged in as:",
    )


def handle_clear_chat(state):
    """Clear the logged-in user's stored and in-session history."""
    if state and state.get("logged_in"):
        save_history(state["user_id"], [])
        state["history"] = []
    return [], state


# ------------------------------
# 🧱 Gradio UI
# ------------------------------
with gr.Blocks(title="SupportPal Chatbot") as demo:
    storage_mode = "☁️ MongoDB Connected" if use_mongo else "💾 Local Storage"
    session_state = gr.State(value={"logged_in": False, "user_id": None, "username": None, "history": []})
    gr.Markdown(f"## 🤖 SupportPal — Empathetic Chatbot\n**Persistence Mode:** {storage_mode}")

    # Auth Panel
    with gr.Row(visible=True) as auth_panel:
        with gr.Column():
            auth_status = gr.Markdown("Please log in or register to start chatting.")
            username_box = gr.Textbox(label="Username")
            password_box = gr.Textbox(label="Password", type="password")
            with gr.Row():
                login_btn = gr.Button("Login", variant="primary")
                register_btn = gr.Button("Register", variant="secondary")

    # Chat Panel
    with gr.Column(visible=False) as chat_panel:
        user_status = gr.Markdown("Logged in as:")
        chatbot_ui = gr.Chatbot(height=400, type="messages")
        with gr.Row():
            msg = gr.Textbox(label="Your message...", scale=4)
            send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            logout_btn = gr.Button("Logout", variant="stop")

    # Bind events
    auth_outputs = [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, user_status]
    login_btn.click(handle_login, [username_box, password_box, session_state], auth_outputs)
    register_btn.click(handle_register, [username_box, password_box, session_state], auth_outputs)
    send_btn.click(process_chat, [msg, session_state], [chatbot_ui, session_state, auth_status]).then(
        lambda: gr.update(value=""), None, [msg]
    )
    msg.submit(process_chat, [msg, session_state], [chatbot_ui, session_state, auth_status]).then(
        lambda: gr.update(value=""), None, [msg]
    )
    clear_btn.click(handle_clear_chat, [session_state], [chatbot_ui, session_state])
    logout_btn.click(
        handle_logout,
        None,
        [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, msg, user_status],
    )

demo.launch()