# Phi2-chatbox / app.py
# Author: Estherrr777
# Last change: commit 134b220 ("Update app.py")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Repository id of the chat model; tokenizer and weights both come from it.
model_name = "Estherrr777/phi2-chat"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Both pipelines are pinned to the CPU (device=-1).
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=-1,
)

# Candidate labels for the zero-shot safety classifier.
danger_labels = [
    "self-harm",
    "suicide",
    "crisis",
    "mental health emergency",
    "postpartum danger",
]
# Chat function with history + zero-shot risk detection.

# Prompt speaker name for each "messages"-format role; other roles are skipped.
_ROLE_NAMES = {"user": "User", "assistant": "Assistant"}
# Markers that begin a new transcript turn in the prompt format below.
_TURN_MARKERS = ("\nUser:", "\nAssistant:")
# Any danger label scoring above this triggers the warning (adjust if needed).
_DANGER_THRESHOLD = 0.7


def _build_prompt(message, history):
    """Render prior turns plus ``message`` as a ``User:``/``Assistant:`` transcript.

    Accepts Gradio history either as ``[user, assistant]`` pairs ("tuples"
    format) or as dicts with ``role``/``content`` keys ("messages" format).
    The result ends with an open ``Assistant:`` line for the model to complete.
    """
    parts = []
    for turn in history or ():
        if isinstance(turn, dict):
            speaker = _ROLE_NAMES.get(turn.get("role"))
            if speaker is not None:
                parts.append(f"{speaker}: {turn.get('content', '')}\n")
        else:
            user_msg, bot_reply = turn
            # bot_reply may be None; render it as empty rather than "None".
            reply_text = "" if bot_reply is None else bot_reply
            parts.append(f"User: {user_msg}\nAssistant: {reply_text}\n")
    parts.append(f"User: {message}\nAssistant:")
    return "".join(parts)


def _first_turn(generated):
    """Return ``generated`` cut at the first new ``User:``/``Assistant:`` turn, stripped."""
    # A completion model sampling this transcript format can continue past its
    # own reply and invent the next turn; keep only the assistant's reply.
    for marker in _TURN_MARKERS:
        generated = generated.split(marker, 1)[0]
    return generated.strip()


def chat_fn(message, history):
    """Generate a reply to ``message`` and prepend a safety warning when needed.

    Args:
        message: The latest user message.
        history: Prior conversation turns as passed by ``gr.ChatInterface``
            ("tuples" or "messages" format).

    Returns:
        The generated assistant reply. It is prefixed with a warning when the
        zero-shot classifier scores any of ``danger_labels`` above
        ``_DANGER_THRESHOLD`` for ``message``.
    """
    full_prompt = _build_prompt(message, history)

    # return_full_text=False makes the pipeline return only the newly generated
    # text, so we don't depend on the decoded output reproducing the prompt.
    outputs = generator(
        full_prompt,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    new_reply = _first_turn(outputs[0]["generated_text"])

    # multi_label=True scores each label independently, so several can be
    # likely at once (scores need not sum to 1).
    result = classifier(message, candidate_labels=danger_labels, multi_label=True)
    if any(score > _DANGER_THRESHOLD for score in result["scores"]):
        warning = "⚠️ This sounds serious. Please contact your healthcare provider or local emergency services immediately.\n\n"
        return warning + new_reply
    return new_reply
# Gradio Chat UI.
# The app is bound to a module-level `demo` (the name `gradio app.py` reload
# mode looks for). The __main__ guard means importing this module builds the
# UI without starting a server; `python app.py` still launches it.
demo = gr.ChatInterface(
    fn=chat_fn,
    title="🤰 Pregnancy Support Chat (Phi-2 LoRA)",
    description="A calm, informative assistant fine-tuned on pregnancy-related conversations. Ask anything about pregnancy, anxiety, or health support.",
    theme=gr.themes.Soft(primary_hue="rose", font=["Open Sans", "Arial"]),
)

if __name__ == "__main__":
    demo.launch()