# MMHP / app.py
# omar47's picture
# Update app.py
# 6761fa6 verified
import gradio as gr
import os
import string
from pymongo import MongoClient
from openai import AsyncOpenAI, OpenAI
import copy
from constants import *
import asyncio
import string as st
from opik.integrations.openai import track_openai
from opik import track
from bson.objectid import ObjectId
import opik
from textwrap import dedent
# --- Module-level clients and database handle (created at import time) ---
oClient = opik.Opik()  # Opik tracing client; NOTE(review): not referenced elsewhere in this file
mdb = MongoClient(
    os.getenv("MONGO_URI")
)  # , "mongodb://localhost:27017/")) # Default to localhost if not set
aclient = track_openai(AsyncOpenAI())  # async OpenAI client, wrapped for Opik tracing
client = track_openai(OpenAI())  # sync OpenAI client; NOTE(review): not referenced elsewhere in this file
db = mdb["Mindware"]  # application database
def purge(d):
    """
    Recursively collect all leaf nodes of a nested dict into one flat dict.

    Nested dicts are merged in; lists of dicts are flattened into per-key
    lists; other values are stored under their original key. On key
    collision, later values overwrite earlier ones.
    """
    result = {}
    for k, v in d.items():
        if k == "chat_history":
            # NOTE(review): dead branch — `pass` falls straight through, so
            # chat_history is NOT skipped; this was probably meant to be
            # `continue`. Confirm intent before changing.
            pass
        if isinstance(v, dict):
            # Merge the nested dict's leaves directly into the flat result.
            result.update(purge(v))
        elif isinstance(v, list):
            # NOTE(review): the loop variable `d` shadows the parameter `d`.
            # Harmless here (d.items() is already bound) but confusing.
            for idx, d in enumerate(v):
                if isinstance(d, dict):
                    try:
                        # NOTE(review): result[k1] is reset to [] for EVERY
                        # dict in the list, so only the last dict's values
                        # survive — setdefault-style accumulation was likely
                        # intended.
                        for k1 in d.keys():
                            result[k1] = []
                        for k2, v2 in d.items():
                            result[k2].append(v2)
                    except Exception as e:
                        # Best-effort fallback: keep the raw item under the
                        # list's own key.
                        print("Error! Error!", e)
                        if k not in result.keys():
                            result[k] = []
                        result[k].append(d)
                else:
                    # Non-dict list item: store the whole list as-is.
                    result[k] = v
            else:
                # NOTE(review): this is a for-else, which runs whenever the
                # loop ends without `break` — i.e. ALWAYS here — so the whole
                # list is also stored under k after the flattening above.
                # It was probably meant to pair with the `if` instead.
                result[k] = v
        else:
            # Plain leaf value.
            result[k] = v
    return result
def deploy(d):
    """Return a flat single-level copy of *d*'s leaf values via purge()."""
    return dict(purge(d))
async def chat(prompt, model="gpt-4"):
    """Send *prompt* as a single user message and return the reply text."""
    completion = await aclient.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content
async def chat_generator(prompt: str, user: dict = None):
    """Stream a GPT-4 reply chunk by chunk.

    When *user* is given, the prompt is filled in with the user's flattened
    fields (via deploy()) before being sent.
    """
    content = prompt.format(**deploy(user)) if user else prompt
    stream = await aclient.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": content}],
        stream=True,
    )
    async for part in stream:
        delta = part.choices[0].delta.content if part else None
        if delta:
            yield delta
def get_or_init_user(reply, userId):
    """Fetch the user document for *userId*, or create one from USER_TEMPLATE.

    The incoming message *reply* is stored under "user_query" in both cases.
    Mongo's "_id" field is stripped from fetched documents so the dict can be
    re-inserted later by update_user() without a duplicate-key error.
    """
    # find_one avoids materializing every match when only the first is used.
    doc = db["users"].find_one({"userId": userId})
    if doc is None:
        # New user: deep-copy the template so the shared constant is never mutated.
        user = copy.deepcopy(USER_TEMPLATE)
        user.update({"userId": userId, "user_query": reply})
        print("user created:", user)
    else:
        # Shallow-copy so we don't hand callers the driver's raw document,
        # then drop "_id" (pop with default: tolerate a missing field).
        user = dict(doc)
        user.pop("_id", None)
        user["user_query"] = reply
    return user
async def search(query, n=5):
    """
    Vector search over db["runway"] for the top *n* THERAPIST documents.

    Embeds *query* with text-embedding-3-small and runs an Atlas
    $vectorSearch aggregation against the "arrestor" index; the stored
    embedding vectors are projected out of the returned documents.
    """
    embed = await aclient.embeddings.create(input=query, model="text-embedding-3-small")
    query_embedding = embed.data[0].embedding
    pipeline = [
        {
            "$vectorSearch": {
                "queryVector": query_embedding,
                "path": "embedding",  # field holding the stored vectors
                "index": "arrestor",
                # NOTE(review): Atlas usually exposes the score through a
                # separate $project stage with {"$meta": "vectorSearchScore"},
                # not as an option inside $vectorSearch — confirm this field
                # is accepted by the server.
                "score": {"$meta": "vectorSearchScore"},
                "filter": {"class": "THERAPIST"},  # pre-filter by document class
                "numCandidates": 850,  # ANN candidate pool before the top-n cut
                "limit": n,
            }
        }
    ]
    # Strip the (large) embedding field from the results.
    projection = [{"$project": {"embedding": 0}}]
    pipeline += projection
    docs = db["runway"].aggregate(pipeline)
    return list(docs)
async def agentic_search(query, n=5):
    """Vector search with an LLM relevance filter.

    Over-fetches n*5 candidates via search(), asks the model (ARAG_PROMPT)
    whether each one is useful context for *query*, and keeps at most *n*
    documents the model answered "True" for.
    """
    results = await search(query, n * 5)
    tasks = [chat(ARAG_PROMPT.format(query=query, doc=doc)) for doc in results]
    # BUG FIX: asyncio.gather() returns an awaitable future; without `await`,
    # `is_context` never held the reply strings and the zip below was broken.
    is_context = await asyncio.gather(*tasks)
    docs = []
    for doc, reply in zip(results, is_context):
        if reply == "True":
            docs.append(doc)
        if len(docs) >= n:
            break
    return list(docs)
async def update_user(response, user):
    """Persist the latest exchange: append both turns to chat_history and
    rewrite the user's document in Mongo (delete-then-insert on userId)."""
    user["last_question"] = response
    history = user["chat_history"]
    history.append({"role": "user", "content": user["user_query"]})
    history.append({"role": "assistant", "content": response})
    collection = db["users"]
    collection.delete_many({"userId": user["userId"]})
    collection.insert_one(user)
    print("Updated", user["userId"], user["user_query"], "reply:", response)
async def add_background_tasks(task):
    """Await *task* inline — a dummy stand-in for FastAPI background tasks."""
    await task
punc_removal = str.maketrans("", "", string.punctuation.replace("_", ""))
async def escalate(user):
    """Stub escalation hook: log that a clinician should take over."""
    name = user.get("name")
    print(f"user {name} is not working, escalating to clinician")
async def update_docs(user):
    """Stub doc-sync hook: log the cached info that would be written out."""
    name, cache = user["name"], user["cache"]
    print(f"updating docs for {name}: {cache}")
@track
async def handle_chat(reply, userId):
    """
    Handle one user message end-to-end and stream the assistant's reply.

    Classifies the message with four parallel GPT calls (intent, risk,
    cache, intensity), extends BASE_PROMPT accordingly, streams the final
    answer word-by-word, and persists the updated user afterwards.
    """
    user = get_or_init_user(reply, userId)
    prompt = BASE_PROMPT
    # Fan out the four classification prompts concurrently.
    tasks = [
        chat(prompt=p.format(**deploy(user)))
        for p in [INTENT_PROMPT, RISK_PROMPT, CACHE_PROMPT, INTENSITY_PROMPT]
    ]
    # responses: [0]=intent, [1]=risk, [2]=cache, [3]=intensity.
    # NOTE(review): responses[3] (intensity) is never used below.
    responses = await asyncio.gather(*tasks)
    if responses[2]:
        # The cache-extraction call produced something: store it and sync docs.
        user["cache"] = responses[2]
        await add_background_tasks(update_docs(user))
    # Normalize the model's label: uppercase, strip punctuation (except "_"),
    # spaces -> underscores, e.g. "active speaking." -> "ACTIVE_SPEAKING".
    intent = responses[0].upper().translate(punc_removal).replace(" ", "_")
    if intent == "ACTIVE_SPEAKING":
        prompt += SPEAKING_PROMPT
    elif intent == "VALIDATION_SEEK":
        prompt += VALIDATION_PROMPT
    elif intent == "OVERWHELMED":
        prompt += OVERWHELMED_PROMPT
        # NOTE(review): 5-second pause before answering an overwhelmed user —
        # confirm this is intentional pacing and not leftover debug code.
        await asyncio.sleep(5)
    elif intent == "REMOTE_REFERRAL":
        # Ground the referral in vector-search results over the cached info.
        results = await search(user["cache"], n=5)
        prompt += REMOTE_PROMPT.format(results=results, **deploy(user))
    elif intent == "NEUTRAL_STOP":
        prompt += STOP_PROMPT
    elif intent == "END_OF_NARRATIVE":
        prompt += END_PROMPT
    else:
        print("Unknown response of intent detection:", responses[0])
    # Same normalization for the risk label; escalate on HIGH_RISK.
    if responses[1].upper().translate(punc_removal).replace(" ", "_") == "HIGH_RISK":
        prompt += HIGH_RISK_PROMPT
        await add_background_tasks(escalate(user))
    response = ""
    # Stream the reply to the caller while accumulating it for persistence.
    async for word in chat_generator(prompt, user):
        if word:
            response += word
            yield word
    # Persist the full exchange once streaming finishes.
    await add_background_tasks(update_user(response, user))
    return
async def respond(message, history, id):
    """
    Gradio ChatInterface callback: yield the growing assistant reply so the
    UI shows the message being typed out.
    """
    accumulated = ""
    async for chunk in handle_chat(message, id):
        if not chunk:
            continue
        accumulated += chunk
        yield accumulated
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
with gr.Blocks() as demo:
    # Pre-fill the user-ID box with a fresh ObjectId so each session gets a
    # distinct default identity. Renamed from `id` to avoid shadowing the
    # builtin id() at module scope.
    user_id_box = gr.Textbox(str(ObjectId()), label="userID")
    gr.ChatInterface(
        fn=respond,
        type="messages",
        additional_inputs=[user_id_box],
    )
if __name__ == "__main__":
    demo.launch(ssr_mode=False)