|
|
import gradio as gr |
|
|
import os |
|
|
import string |
|
|
from pymongo import MongoClient |
|
|
from openai import AsyncOpenAI, OpenAI |
|
|
import copy |
|
|
|
|
|
from constants import * |
|
|
import asyncio |
|
|
import string as st |
|
|
from opik.integrations.openai import track_openai |
|
|
from opik import track |
|
|
from bson.objectid import ObjectId |
|
|
import opik |
|
|
from textwrap import dedent |
|
|
|
|
|
|
|
|
# Opik client for tracing/experiment logging (complements the @track
# decorator used on handle_chat below).
oClient = opik.Opik()

# MongoDB connection; URI comes from the environment so credentials stay
# out of source control.  NOTE(review): if MONGO_URI is unset,
# MongoClient(None) silently falls back to localhost — confirm intended.
mdb = MongoClient(
    os.getenv("MONGO_URI")
)

# OpenAI clients wrapped so every LLM call is traced by Opik.
aclient = track_openai(AsyncOpenAI())  # async client: chat + embeddings
client = track_openai(OpenAI())  # sync client; not referenced in this chunk
# Application database handle.
db = mdb["Mindware"]
|
|
|
|
|
|
|
|
def purge(d):
    """
    Recursively flatten a nested dict into a flat dict of leaf values.

    Rules:
      - A nested dict is flattened in place: its leaves are merged into the
        result under their own (leaf) keys; the parent key is dropped.
      - A list of dicts is transposed into per-key lists: every dict in the
        list contributes its value for key ``k2`` to ``result[k2]``.
      - A list containing a non-dict element is stored as-is under the
        original key.
      - Any other value is stored unchanged.

    Later keys overwrite earlier ones on collision.

    Fixes vs. the original:
      - the inner loop variable no longer shadows the argument ``d``;
      - per-key lists are accumulated with ``setdefault`` instead of being
        re-created for every list element, which used to discard every
        dict's values except the last one's;
      - the dead ``if k == "chat_history": pass`` branch and the blanket
        ``except Exception`` (which only printed) were removed.

    :param d: arbitrarily nested dict.
    :return: flat dict of leaf nodes, suitable for ``str.format(**...)``.
    """
    result = {}
    for k, v in d.items():
        if isinstance(v, dict):
            # Merge the sub-dict's leaves directly into our result.
            result.update(purge(v))
        elif isinstance(v, list):
            for item in v:
                if isinstance(item, dict):
                    # Transpose: collect each inner key's values across
                    # all dicts in the list.
                    for k2, v2 in item.items():
                        result.setdefault(k2, []).append(v2)
                else:
                    # Plain/mixed list: keep the whole list under its key.
                    result[k] = v
        else:
            result[k] = v
    return result
|
|
|
|
|
|
|
|
def deploy(d):
    """
    Flatten *d* into a fresh dict of leaf nodes.

    Thin wrapper around :func:`purge`; returns a new dict so callers can
    mutate the result without touching anything shared.
    """
    return dict(purge(d))
|
|
|
|
|
|
|
|
async def chat(prompt, model="gpt-4"):
    """Send a single-turn user *prompt* to *model* and return the reply text."""
    completion = await aclient.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content
|
|
|
|
|
|
|
|
async def chat_generator(prompt: str, user: dict = None, model: str = "gpt-4"):
    """
    Stream a chat completion chunk-by-chunk.

    :param prompt: prompt template; when *user* is given it is formatted
        with the flattened user fields (``prompt.format(**deploy(user))``).
    :param user: optional user document whose leaves fill the template.
    :param model: chat model name; parameterized (same default as before)
        for consistency with :func:`chat`.
    :yields: non-empty content deltas as they arrive from the API.
    """
    content = prompt.format(**deploy(user)) if user else prompt
    response = await aclient.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": content}],
        stream=True,
    )

    async for chunk in response:
        # Some stream chunks carry no content (role header / finish marker).
        if chunk and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
|
|
|
|
|
|
|
|
def get_or_init_user(reply, userId):
    """
    Fetch the user document for *userId*, or initialize one from USER_TEMPLATE.

    The current message *reply* is always stored under ``user_query``, and
    Mongo's ``_id`` field is stripped so the dict can be re-inserted later
    by ``update_user``.

    Fixes vs. the original: the stored document was built twice
    (``user = users[0]`` then ``user = dict(users[0])``) and ``pop("_id")``
    raised KeyError if ``_id`` was ever absent — now a single copy and a
    defaulted pop.

    :param reply: the user's latest chat message.
    :param userId: external user identifier (``userId`` field in Mongo).
    :return: plain dict ready for prompt formatting / later persistence.
    """
    users = list(db["users"].find({"userId": userId}))
    if not users:
        # First contact: deep-copy so the shared template is never mutated.
        user = dict(copy.deepcopy(USER_TEMPLATE))
        user["userId"] = userId
        print("user created:", user)
    else:
        # Copy the stored document and drop Mongo's _id (defaulted pop so
        # a missing _id cannot raise).
        user = dict(users[0])
        user.pop("_id", None)
    user["user_query"] = reply
    return user
|
|
|
|
|
|
|
|
async def search(query, n=5):
    """
    Vector search over the ``runway`` collection, filtered to THERAPIST docs.

    Embeds *query* with ``text-embedding-3-small``, runs an Atlas
    ``$vectorSearch`` (index ``arrestor``), and returns up to *n* matching
    documents with their embedding vectors projected out.
    """
    embed = await aclient.embeddings.create(input=query, model="text-embedding-3-small")
    query_embedding = embed.data[0].embedding

    vector_stage = {
        "$vectorSearch": {
            "queryVector": query_embedding,
            "path": "embedding",
            "index": "arrestor",
            "score": {"$meta": "vectorSearchScore"},
            "filter": {"class": "THERAPIST"},
            "numCandidates": 850,
            "limit": n,
        }
    }
    # Drop the (large) embedding vector from the returned documents.
    pipeline = [vector_stage, {"$project": {"embedding": 0}}]

    return list(db["runway"].aggregate(pipeline))
|
|
|
|
|
|
|
|
async def agentic_search(query, n=5):
    """
    Agentic RAG: over-fetch candidates, then LLM-filter them for relevance.

    Fetches ``n * 5`` vector-search hits, asks the model (via ARAG_PROMPT)
    for each one whether it is genuine context for *query*, and keeps the
    first *n* documents whose verdict is exactly ``"True"``.

    Bug fix: ``asyncio.gather(*tasks)`` was missing ``await`` — the code
    zipped ``results`` against an un-awaited future instead of the replies,
    which fails at runtime.

    :param query: search query string.
    :param n: maximum number of documents to return.
    :return: list of at most *n* relevant documents.
    """
    results = await search(query, n * 5)
    tasks = [chat(ARAG_PROMPT.format(query=query, doc=doc)) for doc in results]
    # `await` was missing here: without it is_context is a Future, not replies.
    is_context = await asyncio.gather(*tasks)

    docs = []
    for doc, reply in zip(results, is_context):
        if reply == "True":
            docs.append(doc)
        if len(docs) >= n:
            break

    return list(docs)
|
|
|
|
|
|
|
|
async def update_user(response, user):
    """
    Persist the latest turn: record the assistant reply, append both
    messages to ``chat_history``, and replace the stored user document
    (delete-then-insert acts as an upsert keyed on ``userId``).
    """
    user["last_question"] = response
    turn = [
        {"role": "user", "content": user["user_query"]},
        {"role": "assistant", "content": response},
    ]
    user["chat_history"].extend(turn)

    db["users"].delete_many({"userId": user["userId"]})
    db["users"].insert_one(user)

    print("Updated", user["userId"], user["user_query"], "reply:", response)
|
|
|
|
|
|
|
|
async def add_background_tasks(task):
    """Placeholder for FastAPI's BackgroundTasks: awaits *task* inline
    for now and discards its result."""
    await task
|
|
|
|
|
|
|
|
punc_removal = str.maketrans("", "", string.punctuation.replace("_", "")) |
|
|
|
|
|
|
|
|
async def escalate(user):
    """Stub escalation hook: only logs that *user* should see a clinician."""
    name = user.get("name")
    print(f"user {name} is not working, escalating to clinician")
|
|
|
|
|
|
|
|
async def update_docs(user):
    """Stub doc-sync hook: only logs the user's cached summary."""
    name, cache = user["name"], user["cache"]
    print(f"updating docs for {name}: {cache}")
|
|
|
|
|
|
|
|
@track
async def handle_chat(reply, userId):
    """
    Handle one user message end-to-end and stream the assistant reply.

    Pipeline:
      1. load/initialise the user document,
      2. run intent / risk / cache / intensity classifiers concurrently,
      3. extend the base prompt according to the detected intent and risk,
      4. stream the final completion while accumulating it for persistence.

    :param reply: the user's chat message.
    :param userId: external user identifier.
    :yields: streamed content chunks of the assistant reply.
    """
    user = get_or_init_user(reply, userId)
    prompt = BASE_PROMPT
    # Run the four classifier prompts concurrently; `responses` order
    # matches the prompt list: [intent, risk, cache, intensity].
    tasks = [
        chat(prompt=p.format(**deploy(user)))
        for p in [INTENT_PROMPT, RISK_PROMPT, CACHE_PROMPT, INTENSITY_PROMPT]
    ]
    responses = await asyncio.gather(*tasks)

    if responses[2]:
        # Non-empty cache summary: store it and sync the external docs.
        user["cache"] = responses[2]
        await add_background_tasks(update_docs(user))

    # Normalise the intent label: upper-case, punctuation stripped
    # (underscores preserved by punc_removal), spaces -> underscores.
    intent = responses[0].upper().translate(punc_removal).replace(" ", "_")

    if intent == "ACTIVE_SPEAKING":
        prompt += SPEAKING_PROMPT
    elif intent == "VALIDATION_SEEK":
        prompt += VALIDATION_PROMPT
    elif intent == "OVERWHELMED":
        prompt += OVERWHELMED_PROMPT
        # NOTE(review): looks like a deliberate pause before replying to an
        # overwhelmed user — confirm the 5s sleep is not debug leftover.
        await asyncio.sleep(5)
    elif intent == "REMOTE_REFERRAL":
        # Vector-search therapists using the cached summary as the query.
        results = await search(user["cache"], n=5)
        prompt += REMOTE_PROMPT.format(results=results, **deploy(user))
    elif intent == "NEUTRAL_STOP":
        prompt += STOP_PROMPT
    elif intent == "END_OF_NARRATIVE":
        prompt += END_PROMPT
    else:
        # Unrecognised label: fall through with the base prompt unchanged.
        print("Unknown response of intent detection:", responses[0])

    # Risk label is normalised the same way as the intent label.
    if responses[1].upper().translate(punc_removal).replace(" ", "_") == "HIGH_RISK":
        prompt += HIGH_RISK_PROMPT
        await add_background_tasks(escalate(user))

    # Stream the final answer while accumulating it for persistence.
    response = ""
    async for word in chat_generator(prompt, user):
        if word:
            response += word
            yield word

    await add_background_tasks(update_user(response, user))

    return
|
|
|
|
|
|
|
|
async def respond(message, history, id):
    """
    Gradio ChatInterface callback: stream the growing reply string.

    *history* is unused here — conversation state lives in Mongo, keyed
    by *id*.
    """
    accumulated = ""
    async for chunk in handle_chat(message, id):
        if not chunk:
            continue
        accumulated += chunk
        yield accumulated
|
|
|
|
|
|
|
|
""" |
|
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
|
""" |
|
|
with gr.Blocks() as demo: |
|
|
id = gr.Textbox(str(ObjectId()), label="userID") |
|
|
gr.ChatInterface( |
|
|
fn=respond, |
|
|
type="messages", |
|
|
additional_inputs=[id], |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch(ssr_mode=False) |
|
|
|