# supportpal / app.py — SupportPal empathetic chatbot (Hugging Face Space)
# Uploaded by bhushanrocks — "trying v3 fix app.py" (commit c32c6c4, verified)
import hashlib
import hmac
import json
import os
import random
import re
import uuid
from typing import Any, Dict, List, Tuple

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from pymongo import MongoClient
from transformers import AutoModelForCausalLM, AutoTokenizer
# ------------------------------
# βš™οΈ Global Config
# ------------------------------
# Run inference on GPU when one is visible to torch; otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repo holding the fine-tuned DialoGPT checkpoint
# (expected to be a PEFT adapter repo; see the model-loading block below).
MODEL_NAME = "bhushanrocks/supportpal-dialoGPT-v3"
# ------------------------------
# πŸ” MongoDB Setup & Local Fallback
# ------------------------------
# Connection string is injected via the environment; when it is missing or
# the cluster is unreachable, the app falls back to local JSON-file storage.
MONGO_URI = os.getenv("MONGO_URI")
use_mongo = False
# Populated only on a successful Mongo connection; otherwise stay None and
# the local_* dicts below are used instead.
db = None
users_collection = None
chats_collection = None
if MONGO_URI:
    try:
        # Short server-selection timeout so startup doesn't hang when the
        # cluster is unreachable; server_info() forces a real round trip so
        # failures surface here rather than on the first query.
        client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        client.server_info()
        db = client["supportpal_db"]
        users_collection = db["users"]
        chats_collection = db["chats"]
        use_mongo = True
        print("βœ… Connected to MongoDB for persistence and authentication.")
    except Exception as e:
        # Any connection/auth failure downgrades to local JSON persistence.
        print(f"⚠️ MongoDB unavailable: {e}. Using local JSON storage.")
else:
    print("⚠️ MONGO_URI not set. Using local JSON storage.")
# On-disk JSON stores used whenever use_mongo is False.
LOCAL_USERS_FILE = "local_users.json"
LOCAL_CHATS_FILE = "local_chats.json"
def load_local_data(filepath):
    """Load a JSON file into a dict.

    Args:
        filepath: Path to the JSON file.

    Returns:
        The parsed dict, or {} when the file is missing, unreadable, or
        contains invalid JSON (a corrupt store should not crash startup).
    """
    if os.path.exists(filepath):
        try:
            # Explicit encoding: open() otherwise uses the platform default
            # (e.g. cp1252 on Windows), which would mangle the emoji that
            # the chat histories contain.
            with open(filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            print(f"⚠️ Could not read {filepath}: {e}. Starting fresh.")
    return {}
def save_local_data(data, filepath):
    """Persist *data* as pretty-printed JSON at *filepath*.

    Args:
        data: JSON-serializable dict to write.
        filepath: Destination path (overwritten in full).
    """
    # Explicit utf-8 keeps the write platform-independent, and
    # ensure_ascii=False stores the chat emoji readably instead of as
    # \uXXXX escapes; json.load round-trips either form identically.
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
# In-memory mirrors of the on-disk JSON stores; only consulted when
# use_mongo is False. Mutated in place and re-saved by the functions below.
local_users = load_local_data(LOCAL_USERS_FILE)
local_chats = load_local_data(LOCAL_CHATS_FILE)
def hash_password(password):
    """Return the hex SHA-256 digest of *password*.

    NOTE(review): unsalted SHA-256 is weak for password storage (vulnerable
    to rainbow tables); hashlib.pbkdf2_hmac would be stronger, but switching
    would invalidate every hash already stored, so it stays as-is here.
    """
    digest = hashlib.sha256(password.encode("utf-8"))
    return digest.hexdigest()
# ------------------------------
# πŸ”‘ Auth Functions
# ------------------------------
def authenticate_user(username, password):
    """Check credentials against Mongo or the local JSON store.

    Args:
        username: Account name to look up.
        password: Plain-text password; compared via its SHA-256 hash.

    Returns:
        (user_id, message) on success, (None, message) on failure.
    """
    hashed = hash_password(password)
    if use_mongo:
        user_data = users_collection.find_one({"username": username})
    else:
        user_data = local_users.get(username)
    # hmac.compare_digest is constant-time, so the comparison doesn't leak
    # how much of the hash matched; .get() tolerates records that lack a
    # "password" field instead of raising KeyError.
    if user_data and hmac.compare_digest(user_data.get("password", ""), hashed):
        return user_data["user_id"], f"Welcome back, {username}!"
    return None, "Invalid username or password."
def register_user(username, password):
    """Create a new account, rejecting blank or duplicate usernames.

    Args:
        username: Desired account name; must be non-blank and unique.
        password: Plain-text password; must be non-empty (stored hashed).

    Returns:
        (user_id, message) on success, (None, message) on failure.
    """
    # Robustness: the old code happily registered "" / whitespace usernames
    # and empty passwords. Validate before touching any storage.
    if not username or not username.strip() or not password:
        return None, "Username and password must not be empty."
    hashed = hash_password(password)
    new_user_id = str(uuid.uuid4())
    if use_mongo:
        if users_collection.find_one({"username": username}):
            return None, "Username already taken."
        users_collection.insert_one(
            {"username": username, "password": hashed, "user_id": new_user_id}
        )
    else:
        if username in local_users:
            return None, "Username already taken."
        local_users[username] = {
            "username": username,
            "password": hashed,
            "user_id": new_user_id,
        }
        save_local_data(local_users, LOCAL_USERS_FILE)
    return new_user_id, f"Account created! Welcome, {username}."
# ------------------------------
# πŸ’Ύ Chat History Handling
# ------------------------------
def get_history(user_id):
    """Fetch a user's stored chat history.

    Args:
        user_id: Opaque user identifier, or a falsy value for "no user".

    Returns:
        List of stored (user, bot) turns; [] when there is no user or no
        saved conversation yet.
    """
    if not user_id:
        return []
    if use_mongo:
        convo = chats_collection.find_one({"user_id": user_id})
        # BUG FIX: find_one returns None for a first-time user; the old
        # code called .get on it and raised AttributeError on first login.
        return convo.get("history", []) if convo else []
    return local_chats.get(user_id, [])
def save_history(user_id, history):
    """Persist *history* for *user_id* to Mongo or the local JSON store."""
    if not user_id:
        # Anonymous/unknown user: nothing to persist.
        return
    if not use_mongo:
        local_chats[user_id] = history
        save_local_data(local_chats, LOCAL_CHATS_FILE)
        return
    # upsert=True creates the conversation document on first save.
    chats_collection.update_one(
        {"user_id": user_id},
        {"$set": {"history": history}},
        upsert=True,
    )
def history_to_gradio_format(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Convert stored chat turns into Gradio "messages"-format dicts.

    Args:
        history: Stored turns — (user, bot) tuples/lists, or dicts shaped
            {"user": ..., "bot": ...}. generate_response already tolerates
            the dict shape when building prompts; this now matches it so a
            dict-shaped history loaded from storage can't crash the UI.

    Returns:
        Flat list of {"role": "user"|"assistant", "content": str} dicts;
        empty messages are skipped.
    """
    gradio_msgs = []
    for turn in history:
        if isinstance(turn, dict):
            user_msg, bot_msg = turn.get("user", ""), turn.get("bot", "")
        else:
            user_msg, bot_msg = turn
        if user_msg:
            gradio_msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            gradio_msgs.append({"role": "assistant", "content": bot_msg})
    return gradio_msgs
# ------------------------------
# πŸ€— Load Model
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
try:
    # Preferred path: MODEL_NAME is a PEFT adapter repo — load its declared
    # base model first, then attach the adapters on top of it.
    config = PeftConfig.from_pretrained(MODEL_NAME)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        # fp16 only on GPU; None keeps the library's default dtype on CPU.
        torch_dtype=torch.float16 if DEVICE == "cuda" else None,
    ).to(DEVICE)
    model = PeftModel.from_pretrained(base_model, MODEL_NAME)
    print(f"βœ… Loaded PEFT model from {MODEL_NAME}")
except Exception as e:
    # Fallback: treat MODEL_NAME as a plain (fully merged) causal LM.
    print(f"⚠️ Failed to load PEFT adapters, fallback: {e}")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference only: disables dropout etc.
# ------------------------------
# πŸ’¬ Chat Logic
# ------------------------------
def generate_response(user_input: str, state_dict: dict):
    """Main generation logic.

    Builds a "Human:/AI:"-style prompt from the system prompt plus the last
    few stored turns, samples a reply from the model, cleans it up, appends
    the new turn to history, and persists it.

    Args:
        user_input: The user's latest message.
        state_dict: Per-session state with keys "logged_in", "user_id",
            "username", and "history" (list of stored turns).

    Returns:
        (gradio_messages, state_dict, status_text) — status_text is "" on
        success, or a login prompt when the session is unauthenticated.
    """
    if not state_dict.get("logged_in"):
        return [], state_dict, "Please log in to start chatting."
    user_id = state_dict["user_id"]
    history = state_dict.get("history", [])
    # Define system prompt here
    system_prompt = (
        "You are SupportPal β€” a kind, compassionate AI that listens patiently, "
        "validates emotions, and gently helps people feel understood. "
        "Avoid talking about yourself, keep responses concise, warm, and natural."
    )
    # Keep only the last 5 turns
    context_turns = history[-5:]
    # Build dialogue text
    prompt_turns = []
    for turn in context_turns:
        # Stored turns may arrive as dicts (deserialized documents) or as
        # 2-element tuples/lists; anything else is silently skipped.
        if isinstance(turn, dict):
            prompt_turns.append(f"Human: {turn.get('user','')}\nAI: {turn.get('bot','')}")
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            prompt_turns.append(f"Human: {turn[0]}\nAI: {turn[1]}")
    dialogue = system_prompt.strip() + "\n\n" + "\n".join(prompt_turns)
    dialogue += f"\nHuman: {user_input}\nAI:"
    # Tokenize and generate
    # max_length=512 truncates the prompt from the left of the model's limit;
    # sampling params give varied but reasonably focused replies.
    inputs = tokenizer(dialogue, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Everything after the last "AI:" marker is the newly generated reply.
    response = decoded.split("AI:")[-1].strip()
    # Clean and finalize
    # Strip any echoed system prompt and formulaic openers the model tends
    # to emit, plus stray "AI:" markers.
    response = re.sub(r"You are SupportPal.*", "", response, flags=re.DOTALL).strip()
    response = re.sub(r"\b(I am|I'm) glad\b.*?\.", "", response)
    response = re.sub(r"\b(I hope|I wish)\b.*?\.", "", response)
    response = response.replace("AI:", "").strip()
    if not response:
        # Cleanup can empty the reply entirely; fall back to a gentle default.
        response = "I'm here with you. Tell me more about what’s on your mind. πŸ’›"
    # ~25% of the time, append a warm sign-off if none is present already.
    if random.random() < 0.25 and not any(x in response for x in ["πŸ’›", "✨"]):
        response += " " + random.choice(["πŸ’›", "✨", "You’re not alone in this."])
    # Save and return
    history.append((user_input, response))
    save_history(user_id, history)
    state_dict["history"] = history
    gradio_history = history_to_gradio_format(history)
    return gradio_history, state_dict, ""
def process_chat(user_message, state_dict):
    """Gradio callback wrapper around generate_response.

    Args:
        user_message: Text from the message box.
        state_dict: Session state dict (may be falsy on a fresh session).

    Returns:
        (chatbot_messages, state_dict, status_text).
    """
    if not state_dict:
        state_dict = {"logged_in": False, "user_id": None, "username": None, "history": []}
    if not state_dict.get("logged_in"):
        return [], state_dict, "Please log in to start chatting."
    # Robustness: don't run generation on a blank/whitespace-only message —
    # just re-render the current history unchanged.
    if not user_message or not user_message.strip():
        return history_to_gradio_format(state_dict.get("history", [])), state_dict, ""
    return generate_response(user_message, state_dict)
# ------------------------------
# πŸ” Login / Register / Logout
# ------------------------------
def handle_login(username, password, state):
    """Log a user in and flip the UI from the auth panel to the chat panel.

    Returns the 6-tuple wired to (auth_status, auth_panel, chat_panel,
    chatbot_ui, session_state, user_status).
    """
    user_id, status = authenticate_user(username, password)
    if not user_id:
        # Failed login: keep the auth panel up and leave the state untouched.
        return status, gr.update(visible=True), gr.update(visible=False), [], state, gr.update()
    history = get_history(user_id)
    state.update(
        {"logged_in": True, "user_id": user_id, "username": username, "history": history}
    )
    return (
        status,
        gr.update(visible=False),
        gr.update(visible=True),
        history_to_gradio_format(history),
        state,
        f"Logged in as: **{username}**",
    )
def handle_register(username, password, state):
    """Create an account and, on success, log the new user straight in.

    Returns the same 6-tuple shape as handle_login.
    """
    user_id, status = register_user(username, password)
    if not user_id:
        # Registration failed (e.g. name taken): stay on the auth panel.
        return status, gr.update(visible=True), gr.update(visible=False), [], state, gr.update()
    state.update(
        {"logged_in": True, "user_id": user_id, "username": username, "history": []}
    )
    return (
        status,
        gr.update(visible=False),
        gr.update(visible=True),
        [],
        state,
        f"Logged in as: **{username}**",
    )
def handle_logout():
    """Reset the session: show the auth panel, clear the chat, blank the state."""
    fresh_state = {"logged_in": False, "user_id": None, "username": None, "history": []}
    return (
        "Logged out successfully.",
        gr.update(visible=True),   # auth panel back on
        gr.update(visible=False),  # chat panel hidden
        [],                        # clear chatbot display
        fresh_state,
        gr.update(value=""),       # empty the message box
        "Logged in as:",
    )
def handle_clear_chat(state):
    """Wipe the logged-in user's chat history (both display and storage)."""
    if state.get("logged_in"):
        # Persist the empty history so the wipe survives reloads.
        save_history(state["user_id"], [])
        state["history"] = []
    # Either way the chatbot display is cleared.
    return [], state
# ------------------------------
# 🧱 Gradio UI
# ------------------------------
with gr.Blocks(title="SupportPal Chatbot") as demo:
    storage_mode = "☁️ MongoDB Connected" if use_mongo else "πŸ’Ύ Local Storage"
    # Per-browser-session state; never shared between concurrent users.
    session_state = gr.State(value={"logged_in": False, "user_id": None, "username": None, "history": []})
    gr.Markdown(f"## πŸ€– SupportPal β€” Empathetic Chatbot\n**Persistence Mode:** {storage_mode}")
    # Auth Panel
    # Shown first; handle_login/handle_register toggle its visibility
    # against the chat panel below.
    with gr.Row(visible=True) as auth_panel:
        with gr.Column():
            auth_status = gr.Markdown("Please log in or register to start chatting.")
            username_box = gr.Textbox(label="Username")
            password_box = gr.Textbox(label="Password", type="password")
            with gr.Row():
                login_btn = gr.Button("Login", variant="primary")
                register_btn = gr.Button("Register", variant="secondary")
    # Chat Panel
    with gr.Column(visible=False) as chat_panel:
        user_status = gr.Markdown("Logged in as:")
        # type="messages" matches the dicts from history_to_gradio_format.
        chatbot_ui = gr.Chatbot(height=400, type="messages")
        with gr.Row():
            msg = gr.Textbox(label="Your message...", scale=4)
            send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            logout_btn = gr.Button("Logout", variant="stop")
    # Bind events
    login_btn.click(handle_login, [username_box, password_box, session_state],
                    [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, user_status])
    register_btn.click(handle_register, [username_box, password_box, session_state],
                       [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, user_status])
    # The chained .then(...) clears the message box after each send/submit.
    send_btn.click(process_chat, [msg, session_state],
                   [chatbot_ui, session_state, auth_status]).then(lambda: gr.update(value=""), None, [msg])
    msg.submit(process_chat, [msg, session_state],
               [chatbot_ui, session_state, auth_status]).then(lambda: gr.update(value=""), None, [msg])
    clear_btn.click(handle_clear_chat, [session_state], [chatbot_ui, session_state])
    logout_btn.click(handle_logout, None,
                     [auth_status, auth_panel, chat_panel, chatbot_ui, session_state, msg, user_status])
demo.launch()