Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel, PeftConfig | |
| from pymongo import MongoClient | |
| import torch | |
| import os | |
| import uuid | |
| import json | |
| import re | |
| import random | |
| import hashlib | |
| from typing import List, Tuple, Dict, Any | |
# ------------------------------
# Global Config
# ------------------------------
# Hugging Face Hub repository the tokenizer and model weights are loaded from.
MODEL_NAME = "bhushanrocks/supportpal-dialoGPT-v3"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------------
# MongoDB Setup & Local Fallback
# ------------------------------
# Connection string comes from the environment; when it is missing or the
# server cannot be reached, the app falls back to local JSON files.
MONGO_URI = os.getenv("MONGO_URI")

# Defaults describe the "no MongoDB" state; they are overwritten only after a
# successful connection below.
db = None
users_collection = None
chats_collection = None
use_mongo = False

if not MONGO_URI:
    print("β οΈ MONGO_URI not set. Using local JSON storage.")
else:
    try:
        client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        # Force a round trip so an unreachable server fails here, at startup,
        # instead of on the first query.
        client.server_info()
        db = client["supportpal_db"]
        users_collection = db["users"]
        chats_collection = db["chats"]
        use_mongo = True
        print("β Connected to MongoDB for persistence and authentication.")
    except Exception as e:
        print(f"β οΈ MongoDB unavailable: {e}. Using local JSON storage.")

# File names used by the local JSON fallback store.
LOCAL_USERS_FILE = "local_users.json"
LOCAL_CHATS_FILE = "local_chats.json"
def load_local_data(filepath):
    """Return the JSON content stored at *filepath*, or ``{}`` if it is absent.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON value, or an empty dict when the path does not exist.
    """
    if not os.path.exists(filepath):
        return {}
    with open(filepath, "r") as handle:
        return json.loads(handle.read())
def save_local_data(data, filepath):
    """Write *data* to *filepath* as indented JSON.

    The payload is serialized before the file is opened. Opening with ``"w"``
    truncates the file immediately, so serializing afterwards would wipe the
    existing store whenever *data* turned out not to be JSON-serializable.

    Args:
        data: JSON-serializable object to persist.
        filepath: Destination path; it is overwritten.

    Raises:
        TypeError: If *data* contains a value ``json`` cannot encode. The
            existing file is left untouched in that case.
    """
    payload = json.dumps(data, indent=4)
    with open(filepath, "w") as f:
        f.write(payload)
# In-memory copies of the local JSON stores, read once at import time
# (users file first, then chats file).
local_users, local_chats = (
    load_local_data(path) for path in (LOCAL_USERS_FILE, LOCAL_CHATS_FILE)
)
def hash_password(password):
    """Return the hex-encoded SHA-256 digest of *password*.

    NOTE(review): this is a single unsalted SHA-256 pass. Stored hashes depend
    on this exact output, so switching to a salted, slow KDF needs a migration.
    """
    digest = hashlib.sha256()
    digest.update(password.encode())
    return digest.hexdigest()
| # ------------------------------ | |
| # Auth Functions | |
| # ------------------------------ | |
def authenticate_user(username, password):
    """Check *username* / *password* against the active user store.

    Args:
        username: Name to look up (MongoDB ``users`` collection when
            ``use_mongo`` is set, otherwise the ``local_users`` dict).
        password: Plain-text password; it is hashed before comparison.

    Returns:
        ``(user_id, message)`` on success, ``(None, message)`` otherwise.
    """
    expected_hash = hash_password(password)
    record = (
        users_collection.find_one({"username": username})
        if use_mongo
        else local_users.get(username)
    )
    # One generic failure message for both "no such user" and "wrong password".
    if not record or record["password"] != expected_hash:
        return None, "Invalid username or password."
    return record["user_id"], f"Welcome back, {username}!"
def register_user(username, password):
    """Create a new account in the active user store.

    Blank usernames (empty or whitespace-only) and empty passwords are
    rejected up front so that no unusable account is ever persisted.

    Args:
        username: Desired account name; stored exactly as given.
        password: Plain-text password; only its hash is stored.

    Returns:
        ``(user_id, message)`` when the account was created, or
        ``(None, message)`` when the input is blank or the name is taken.
    """
    if not username or not username.strip() or not password:
        return None, "Username and password are required."

    # NOTE(review): the existence check and the insert are separate steps, so
    # two simultaneous registrations of the same name are not excluded here.
    if use_mongo:
        if users_collection.find_one({"username": username}):
            return None, "Username already taken."
    elif username in local_users:
        return None, "Username already taken."

    new_user_id = str(uuid.uuid4())
    record = {
        "username": username,
        "password": hash_password(password),
        "user_id": new_user_id,
    }
    if use_mongo:
        users_collection.insert_one(record)
    else:
        local_users[username] = record
        save_local_data(local_users, LOCAL_USERS_FILE)
    return new_user_id, f"Account created! Welcome, {username}."
| # ------------------------------ | |
| # Chat History Handling | |
| # ------------------------------ | |
def get_history(user_id):
    """Return the stored chat history for *user_id*.

    Args:
        user_id: Identifier of the user; a falsy value yields an empty list.

    Returns:
        The saved list of turns, or ``[]`` when the user has no saved chat.
    """
    if not user_id:
        return []
    if not use_mongo:
        return local_chats.get(user_id, [])

    convo = chats_collection.find_one({"user_id": user_id})
    # find_one returns None when no document matches, which is the case for a
    # user who has registered but never chatted.
    if convo is None:
        return []
    return convo.get("history") or []
def save_history(user_id, history):
    """Persist *history* for *user_id* in the active store.

    Does nothing when *user_id* is falsy. With MongoDB the user's document is
    upserted; otherwise the in-memory dict is updated and written to disk.
    """
    if user_id:
        if use_mongo:
            chats_collection.update_one(
                {"user_id": user_id},
                {"$set": {"history": history}},
                upsert=True,
            )
        else:
            local_chats[user_id] = history
            save_local_data(local_chats, LOCAL_CHATS_FILE)
def history_to_gradio_format(history: List[Tuple[str, str]]):
    """Flatten ``(user, bot)`` turns into role/content message dicts.

    Each turn contributes a ``"user"`` message followed by an ``"assistant"``
    message; a side whose text is falsy is skipped.
    """
    messages = []
    for user_text, bot_text in history:
        messages.extend(
            {"role": role, "content": text}
            for role, text in (("user", user_text), ("assistant", bot_text))
            if text
        )
    return messages
# ------------------------------
# Load Model
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Half precision is requested only on CUDA; otherwise the library default is used.
_load_dtype = torch.float16 if DEVICE == "cuda" else None

try:
    # Treat MODEL_NAME as a PEFT adapter repo first: load the base model named
    # in its config, then attach the adapter weights on top.
    config = PeftConfig.from_pretrained(MODEL_NAME)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=_load_dtype,
    )
    base_model = base_model.to(DEVICE)
    model = PeftModel.from_pretrained(base_model, MODEL_NAME)
    print(f"β Loaded PEFT model from {MODEL_NAME}")
except Exception as e:
    # Any failure above falls back to loading MODEL_NAME as a plain causal LM.
    print(f"β οΈ Failed to load PEFT adapters, fallback: {e}")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)

model.eval()
| # ------------------------------ | |
| # Chat Logic | |
| # ------------------------------ | |
# Number of past turns included in the prompt.
_MAX_CONTEXT_TURNS = 5
# Upper bound on prompt tokens passed to the model.
_MAX_PROMPT_TOKENS = 512


def _build_dialogue(history, user_input):
    """Render the system prompt, recent turns and the new message as one prompt."""
    # NOTE(review): the non-ASCII characters in the string literals of this
    # block look mis-encoded; confirm them against the original file.
    system_prompt = (
        "You are SupportPal β a kind, compassionate AI that listens patiently, "
        "validates emotions, and gently helps people feel understood. "
        "Avoid talking about yourself, keep responses concise, warm, and natural."
    )
    prompt_turns = []
    for turn in history[-_MAX_CONTEXT_TURNS:]:
        if isinstance(turn, dict):
            prompt_turns.append(f"Human: {turn.get('user','')}\nAI: {turn.get('bot','')}")
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            prompt_turns.append(f"Human: {turn[0]}\nAI: {turn[1]}")
    dialogue = system_prompt.strip() + "\n\n" + "\n".join(prompt_turns)
    return dialogue + f"\nHuman: {user_input}\nAI:"


def _clean_reply(generated):
    """Reduce raw generated text to a single assistant reply."""
    # Drop anything from the first invented "Human:" turn onwards, so a reply
    # to a message the user never wrote cannot be returned.
    response = generated.split("Human:", 1)[0]
    response = response.split("AI:")[-1].strip()
    response = re.sub(r"You are SupportPal.*", "", response, flags=re.DOTALL).strip()
    response = re.sub(r"\b(I am|I'm) glad\b.*?\.", "", response)
    response = re.sub(r"\b(I hope|I wish)\b.*?\.", "", response)
    return response.replace("AI:", "").strip()


def generate_response(user_input: str, state_dict: dict):
    """Generate the assistant's reply and record the turn.

    Args:
        user_input: The message the user just typed.
        state_dict: Session state; reads ``logged_in``, ``user_id`` and
            ``history``, and has ``history`` updated in place.

    Returns:
        ``(chat_messages, state_dict, status_text)`` where ``chat_messages``
        is the history in role/content form. When the session is not logged
        in, the messages list is empty and the status asks the user to log in.
        Blank input returns the current history unchanged.
    """
    if not state_dict.get("logged_in"):
        return [], state_dict, "Please log in to start chatting."
    user_id = state_dict["user_id"]
    history = state_dict.get("history", [])

    # Nothing to answer: do not call the model or store an empty turn.
    if not user_input or not user_input.strip():
        return history_to_gradio_format(history), state_dict, ""

    dialogue = _build_dialogue(history, user_input)

    # Keep the LAST tokens of an over-long prompt. Truncating on the right
    # would cut off the newest "Human: ... AI:" cue, which is the part the
    # model must answer.
    encoded = tokenizer(dialogue, return_tensors="pt")
    input_ids = encoded["input_ids"][:, -_MAX_PROMPT_TOKENS:].to(model.device)
    attention_mask = encoded["attention_mask"][:, -_MAX_PROMPT_TOKENS:].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=200,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt; decode only what follows it.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    response = _clean_reply(tokenizer.decode(new_tokens, skip_special_tokens=True))

    if not response:
        response = "I'm here with you. Tell me more about whatβs on your mind. π"
    # Occasionally append a short closing flourish unless one is already there.
    if random.random() < 0.25 and not any(x in response for x in ["π", "β¨"]):
        response += " " + random.choice(["π", "β¨", "Youβre not alone in this."])

    # Save and return
    history.append((user_input, response))
    save_history(user_id, history)
    state_dict["history"] = history
    return history_to_gradio_format(history), state_dict, ""
def process_chat(user_message, state_dict):
    """Gradio callback: answer *user_message* for a logged-in session.

    A missing or empty state is replaced by a fresh logged-out session. When
    the session is not logged in, no reply is generated and a login prompt is
    returned as the status text.
    """
    session = state_dict or {
        "logged_in": False,
        "user_id": None,
        "username": None,
        "history": [],
    }
    if session.get("logged_in"):
        return generate_response(user_message, session)
    return [], session, "Please log in to start chatting."
| # ------------------------------ | |
| # Login / Register / Logout | |
| # ------------------------------ | |
def handle_login(username, password, state):
    """Gradio callback for the Login button.

    Returns six values in this order: status text, auth-panel update,
    chat-panel update, chat messages, session state, "logged in as" text.
    On failure the auth panel stays visible and the state is left unchanged.
    """
    user_id, status = authenticate_user(username, password)
    if not user_id:
        return status, gr.update(visible=True), gr.update(visible=False), [], state, gr.update()

    # Successful login: restore the saved conversation into the session.
    history = get_history(user_id)
    state.update(logged_in=True, user_id=user_id, username=username, history=history)
    return (
        status,
        gr.update(visible=False),
        gr.update(visible=True),
        history_to_gradio_format(history),
        state,
        f"Logged in as: **{username}**",
    )
def handle_register(username, password, state):
    """Gradio callback for the Register button.

    Returns six values in this order: status text, auth-panel update,
    chat-panel update, chat messages, session state, "logged in as" text.
    A successful registration logs the new user in with an empty history.
    """
    user_id, status = register_user(username, password)
    if not user_id:
        return status, gr.update(visible=True), gr.update(visible=False), [], state, gr.update()

    state.update(logged_in=True, user_id=user_id, username=username, history=[])
    return (
        status,
        gr.update(visible=False),
        gr.update(visible=True),
        [],
        state,
        f"Logged in as: **{username}**",
    )
def handle_logout():
    """Gradio callback for the Logout button.

    Returns seven values in this order: status text, auth-panel update,
    chat-panel update, chat messages, a fresh logged-out session state,
    message-box update, "logged in as" text.
    """
    fresh_session = {"logged_in": False, "user_id": None, "username": None, "history": []}
    show_auth_panel = gr.update(visible=True)
    hide_chat_panel = gr.update(visible=False)
    clear_message_box = gr.update(value="")
    return (
        "Logged out successfully.",
        show_auth_panel,
        hide_chat_panel,
        [],
        fresh_session,
        clear_message_box,
        "Logged in as:",
    )
def handle_clear_chat(state):
    """Gradio callback for the Clear Chat button.

    For a logged-in session the stored history is replaced by an empty list,
    both in the persistent store and in *state*. A logged-out session is
    returned untouched. The chat display is emptied in both cases.
    """
    if state.get("logged_in"):
        save_history(state["user_id"], [])
        state["history"] = []
    return [], state
# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="SupportPal Chatbot") as demo:
    storage_mode = "βοΈ MongoDB Connected" if use_mongo else "πΎ Local Storage"
    # Session state starts logged out with an empty history.
    session_state = gr.State(
        value={"logged_in": False, "user_id": None, "username": None, "history": []}
    )
    gr.Markdown(f"## π€ SupportPal β Empathetic Chatbot\n**Persistence Mode:** {storage_mode}")

    # Auth panel: the only panel visible until a login or registration succeeds.
    with gr.Row(visible=True) as auth_panel:
        with gr.Column():
            auth_status = gr.Markdown("Please log in or register to start chatting.")
            username_box = gr.Textbox(label="Username")
            password_box = gr.Textbox(label="Password", type="password")
            with gr.Row():
                login_btn = gr.Button("Login", variant="primary")
                register_btn = gr.Button("Register", variant="secondary")

    # Chat panel: hidden until the user is logged in.
    with gr.Column(visible=False) as chat_panel:
        user_status = gr.Markdown("Logged in as:")
        chatbot_ui = gr.Chatbot(height=400, type="messages")
        with gr.Row():
            msg = gr.Textbox(label="Your message...", scale=4)
            send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            logout_btn = gr.Button("Logout", variant="stop")

    # Login and Register take the same inputs and update the same components.
    auth_inputs = [username_box, password_box, session_state]
    auth_outputs = [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, user_status]
    login_btn.click(handle_login, auth_inputs, auth_outputs)
    register_btn.click(handle_register, auth_inputs, auth_outputs)

    # The Send button and pressing Enter in the textbox behave identically:
    # run the chat callback, then empty the message box.
    for send_trigger in (send_btn.click, msg.submit):
        send_trigger(
            process_chat,
            [msg, session_state],
            [chatbot_ui, session_state, auth_status],
        ).then(lambda: gr.update(value=""), None, [msg])

    clear_btn.click(handle_clear_chat, [session_state], [chatbot_ui, session_state])
    logout_btn.click(
        handle_logout,
        None,
        [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, msg, user_status],
    )

demo.launch()