Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import uuid
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 8 |
+
from langchain_core.messages import BaseMessage
|
| 9 |
+
from langchain_core.runnables.history import RunnableWithMessageHistory
|
| 10 |
+
from langchain_core.runnables import ConfigurableFieldSpec
|
| 11 |
+
from langchain_core.chat_history import BaseChatMessageHistory
|
| 12 |
+
from langchain_groq import ChatGroq
|
| 13 |
+
|
| 14 |
+
# --- Load environment variables from .env ---
|
| 15 |
+
groq_api_key = os.environ["GROQ_API_KEY"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# --- In-memory session-based history ---
|
| 19 |
+
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
|
| 20 |
+
messages: list[BaseMessage] = Field(default_factory=list)
|
| 21 |
+
|
| 22 |
+
def add_messages(self, messages: list[BaseMessage]) -> None:
|
| 23 |
+
self.messages.extend(messages)
|
| 24 |
+
|
| 25 |
+
def clear(self) -> None:
|
| 26 |
+
self.messages = []
|
| 27 |
+
|
| 28 |
+
store = {}
|
| 29 |
+
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
|
| 30 |
+
if session_id not in store:
|
| 31 |
+
store[session_id] = InMemoryHistory()
|
| 32 |
+
return store[session_id]
|
| 33 |
+
|
| 34 |
+
# --- Prompt Template ---
|
| 35 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 36 |
+
("system", "You are an intelligent AI bot which assists in counselling people of all domains. Answer in a friendly, age-appropriate tone. Try your best to solve the user's problem. Never talk about killing or dying."),
|
| 37 |
+
("system", "Person Name: {person_name}\nAge: {age}\nTalk responsibly according to the user's age. Refrain from bad language or harsh topics."),
|
| 38 |
+
("system", "If user asks about your work say you are an AI counseller to colunsl them"),
|
| 39 |
+
("system", "If user asks about who made you: you were made by Mercy AI tech team"),
|
| 40 |
+
MessagesPlaceholder(variable_name="history"),
|
| 41 |
+
("human", "{question}")
|
| 42 |
+
])
|
| 43 |
+
|
| 44 |
+
model = ChatGroq(
|
| 45 |
+
groq_api_key=groq_api_key,
|
| 46 |
+
model_name="llama3-8b-8192"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
chain = prompt | model
|
| 50 |
+
|
| 51 |
+
# --- Chain with History Management ---
|
| 52 |
+
chain_with_history = RunnableWithMessageHistory(
|
| 53 |
+
chain,
|
| 54 |
+
get_session_history=get_by_session_id,
|
| 55 |
+
input_messages_key="question",
|
| 56 |
+
history_messages_key="history",
|
| 57 |
+
history_factory_config=[
|
| 58 |
+
ConfigurableFieldSpec(
|
| 59 |
+
id="session_id",
|
| 60 |
+
annotation=str,
|
| 61 |
+
name="Session ID",
|
| 62 |
+
description="Conversation session ID.",
|
| 63 |
+
default="default-session",
|
| 64 |
+
is_shared=True,
|
| 65 |
+
)
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# --- Streamlit UI ---
|
| 70 |
+
st.set_page_config(page_title="Mercy AI", layout="centered")
|
| 71 |
+
st.title("🤖 Mercy AI — Your AI Companion")
|
| 72 |
+
st.subheader("Your journey matters. I'm here to listen, support, and guide you.")
|
| 73 |
+
|
| 74 |
+
session_id = str(uuid.uuid4()) # You can also use a static ID if preferred.
|
| 75 |
+
|
| 76 |
+
with st.sidebar:
|
| 77 |
+
st.header("Info")
|
| 78 |
+
person_name = st.text_input("Name")
|
| 79 |
+
age = st.text_input("Age")
|
| 80 |
+
st.session_state['session_id'] = session_id
|
| 81 |
+
|
| 82 |
+
# --- Input Validation ---
|
| 83 |
+
def is_valid_name(name):
|
| 84 |
+
return bool(re.fullmatch(r"[A-Za-z\s]{2,50}", name.strip()))
|
| 85 |
+
|
| 86 |
+
def is_valid_age(age_str):
|
| 87 |
+
return age_str.isdigit() and 5 <= int(age_str) <= 100
|
| 88 |
+
|
| 89 |
+
if not person_name or not age:
|
| 90 |
+
st.warning("Please enter your name and age to start.")
|
| 91 |
+
st.stop()
|
| 92 |
+
|
| 93 |
+
if not is_valid_name(person_name):
|
| 94 |
+
st.error("Invalid name. Use only letters and spaces (2-50 characters).")
|
| 95 |
+
st.stop()
|
| 96 |
+
|
| 97 |
+
if not is_valid_age(age):
|
| 98 |
+
st.error("Invalid age. Enter a number between 5 and 120.")
|
| 99 |
+
st.stop()
|
| 100 |
+
|
| 101 |
+
# --- Chat History Setup ---
|
| 102 |
+
if "chat_history" not in st.session_state:
|
| 103 |
+
st.session_state.chat_history = []
|
| 104 |
+
|
| 105 |
+
st.markdown("Type your question below:")
|
| 106 |
+
|
| 107 |
+
user_input = st.chat_input("Ask something...")
|
| 108 |
+
|
| 109 |
+
if user_input:
|
| 110 |
+
with st.spinner("Processing... Please wait."):
|
| 111 |
+
response = chain_with_history.invoke(
|
| 112 |
+
{
|
| 113 |
+
"person_name": person_name.strip(),
|
| 114 |
+
"age": age,
|
| 115 |
+
"question": user_input
|
| 116 |
+
},
|
| 117 |
+
config={"configurable": {"session_id": session_id}}
|
| 118 |
+
)
|
| 119 |
+
st.session_state.chat_history.append(("user", user_input))
|
| 120 |
+
st.session_state.chat_history.append(("bot", response.content))
|
| 121 |
+
|
| 122 |
+
# --- Display chat messages ---
|
| 123 |
+
for role, msg in st.session_state.chat_history:
|
| 124 |
+
with st.chat_message("user" if role == "user" else "assistant"):
|
| 125 |
+
st.markdown(msg)
|