import gradio as gr
import os
import string
from pymongo import MongoClient
from openai import AsyncOpenAI, OpenAI
import copy
from constants import *
import asyncio
import string as st
from opik.integrations.openai import track_openai
from opik import track
from bson.objectid import ObjectId
import opik
from textwrap import dedent

oClient = opik.Opik()
# MONGO_URI must be set in the environment (e.g. "mongodb://localhost:27017/" locally).
mdb = MongoClient(os.getenv("MONGO_URI"))
# Wrap both OpenAI clients so Opik traces every completion/embedding call.
aclient = track_openai(AsyncOpenAI())
client = track_openai(OpenAI())
db = mdb["Mindware"]


def purge(d):
    """Recursively flatten *d* into a single-level dict of leaf values.

    Nested dicts are merged into the result. Lists of dicts are merged
    key-wise: each key accumulates a list of the values seen across the
    list items. Scalar list items accumulate under the list's own key.
    Scalars are copied through unchanged.

    NOTE(review): the key "chat_history" is matched but not skipped
    (`pass`), so chat history is flattened like any other value —
    confirm whether it was meant to be excluded.
    """
    result = {}
    for k, v in d.items():
        if k == "chat_history":
            pass  # presumably meant to skip chat history — TODO confirm
        if isinstance(v, dict):
            result.update(purge(v))
        elif isinstance(v, list):
            for item in v:  # renamed from `d`, which shadowed the parameter
                if isinstance(item, dict):
                    try:
                        # BUGFIX: accumulate values per key across ALL dict
                        # items. Previously `result[k1] = []` reset the list
                        # for every item, so only the last item's values
                        # survived, losing earlier leaves.
                        for k2, v2 in item.items():
                            result.setdefault(k2, []).append(v2)
                    except Exception as e:
                        print("Error! Error!", e)
                else:
                    result.setdefault(k, []).append(item)
        else:
            result[k] = v
    return result


def deploy(d):
    """Return the flattened view of *d* used for prompt `.format(**...)`."""
    result = {}
    result.update(purge(d))
    return result


async def chat(prompt, model="gpt-4"):
    """One-shot completion: send *prompt* as a single user message, return the text."""
    text = await aclient.chat.completions.create(
        model=model, messages=[{"role": "user", "content": prompt}]
    )
    return text.choices[0].message.content


async def chat_generator(prompt: str, user: dict = None):
    """Stream completion chunks for *prompt*.

    When *user* is given, the prompt is `.format()`-ed with the user's
    flattened fields first. Yields each non-empty content delta as it
    arrives.
    """
    response = await aclient.chat.completions.create(
        model="gpt-4",
        messages=[
            {
                "role": "user",
                "content": prompt.format(**deploy(user)) if user else prompt,
            }
        ],
        stream=True,
    )
    async for chunk in response:
        if chunk and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content


def get_or_init_user(reply, userId):
    """Fetch the user document for *userId*, or build one from USER_TEMPLATE.

    Either way the incoming message *reply* is stored as "user_query".
    The Mongo "_id" field is dropped so the dict can be re-inserted later
    (see update_user's delete-then-insert).
    """
    users = list(db["users"].find({"userId": userId}))
    if not users:
        user = copy.deepcopy(USER_TEMPLATE)
        user.update({"userId": userId})
        user.update({"user_query": reply})
        print("user created:", user)
    else:
        # Shallow-copy the stored document and strip the Mongo id
        # (pop with default: tolerant if "_id" is ever absent).
        user = dict(users[0])
        user.pop("_id", None)
        user.update({"user_query": reply})
    return user


async def search(query, n=5):
    """Vector-search the "runway" collection for THERAPIST docs matching *query*.

    Embeds *query* with text-embedding-3-small and runs an Atlas
    $vectorSearch pipeline (index "arrestor"), returning up to *n*
    documents with their embeddings projected out.
    """
    embed = await aclient.embeddings.create(input=query, model="text-embedding-3-small")
    query_embedding = embed.data[0].embedding
    pipeline = [
        {
            "$vectorSearch": {
                "queryVector": query_embedding,
                "path": "embedding",
                "index": "arrestor",
                "score": {"$meta": "vectorSearchScore"},
                "filter": {"class": "THERAPIST"},
                "numCandidates": 850,
                "limit": n,
            }
        },
        # Drop the (large) embedding vectors from the results.
        {"$project": {"embedding": 0}},
    ]
    docs = db["runway"].aggregate(pipeline)
    return list(docs)


async def agentic_search(query, n=5):
    """Vector search with an LLM relevance filter.

    Over-fetches 5x *n* candidates, asks the model (ARAG_PROMPT) whether
    each one is relevant context for *query*, and keeps the first *n*
    that come back "True".
    """
    results = await search(query, n * 5)
    tasks = [chat(ARAG_PROMPT.format(query=query, doc=doc)) for doc in results]
    # BUGFIX: gather must be awaited — previously the un-awaited future
    # object was zipped directly, so the relevance filter never worked.
    is_context = await asyncio.gather(*tasks)
    docs = []
    for doc, reply in zip(results, is_context):
        if reply == "True":
            docs.append(doc)
        if len(docs) >= n:
            break
    return docs


async def update_user(response, user):
    """Persist one chat turn.

    Appends the user query and assistant *response* to chat_history, then
    rewrites the user's document (delete-then-insert keyed on userId).
    """
    user["last_question"] = response
    user["chat_history"].append({"role": "user", "content": user["user_query"]})
    user["chat_history"].append({"role": "assistant", "content": response})
    db["users"].delete_many({"userId": user["userId"]})
    db["users"].insert_one(user)
    print("Updated", user["userId"], user["user_query"], "reply:", response)


async def add_background_tasks(task):
    """a dummy wrapper to be replaced with FastAPI background tasks"""
    await task


# Translation table stripping all punctuation EXCEPT "_" — intent/risk
# labels like ACTIVE_SPEAKING keep their underscores.
punc_removal = str.maketrans("", "", string.punctuation.replace("_", ""))


async def escalate(user):
    """Placeholder escalation hook — currently only logs."""
    print(f"user {user.get('name')} is not working, escalating to clinician")


async def update_docs(user):
    """Placeholder document-update hook — currently only logs."""
    print(f"updating docs for {user['name']}: {user['cache']}")


@track
async def handle_chat(reply, userId):
    """Handle one chat turn for *userId*, streaming the assistant reply.

    Runs intent / risk / cache / intensity classifier prompts concurrently,
    augments BASE_PROMPT according to the detected intent and risk level,
    streams the final completion word-by-word, then persists the updated
    user via a (currently synchronous) background task.
    """
    user = get_or_init_user(reply, userId)
    prompt = BASE_PROMPT
    tasks = [
        chat(prompt=p.format(**deploy(user)))
        for p in [INTENT_PROMPT, RISK_PROMPT, CACHE_PROMPT, INTENSITY_PROMPT]
    ]
    responses = await asyncio.gather(*tasks)

    # responses[2] is the cache extraction; store it when non-empty.
    if responses[2]:
        user["cache"] = responses[2]
        await add_background_tasks(update_docs(user))

    # Normalise the intent label, e.g. "Active speaking." -> "ACTIVE_SPEAKING".
    intent = responses[0].upper().translate(punc_removal).replace(" ", "_")
    if intent == "ACTIVE_SPEAKING":
        prompt += SPEAKING_PROMPT
    elif intent == "VALIDATION_SEEK":
        prompt += VALIDATION_PROMPT
    elif intent == "OVERWHELMED":
        prompt += OVERWHELMED_PROMPT
        # NOTE(review): looks like a deliberate pause before replying to an
        # overwhelmed user — confirm the 5s delay is intended.
        await asyncio.sleep(5)
    elif intent == "REMOTE_REFERRAL":
        results = await search(user["cache"], n=5)
        prompt += REMOTE_PROMPT.format(results=results, **deploy(user))
    elif intent == "NEUTRAL_STOP":
        prompt += STOP_PROMPT
    elif intent == "END_OF_NARRATIVE":
        prompt += END_PROMPT
    else:
        print("Unknown response of intent detection:", responses[0])

    # responses[1] is the risk classification.
    if responses[1].upper().translate(punc_removal).replace(" ", "_") == "HIGH_RISK":
        prompt += HIGH_RISK_PROMPT
        await add_background_tasks(escalate(user))

    response = ""
    async for word in chat_generator(prompt, user):
        if word:
            response += word
            yield word
    await add_background_tasks(update_user(response, user))
    return


async def respond(message, history, user_id):
    """Gradio ChatInterface callback: yield the cumulative reply text.

    *history* is unused but required by the ChatInterface signature;
    *user_id* comes from the additional textbox input (positional, so the
    rename from the builtin-shadowing `id` is call-compatible).
    """
    reply = ""
    async for r in handle_chat(message, user_id):
        if r:
            reply += r
            yield reply


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""
with gr.Blocks() as demo:
    # A fresh ObjectId serves as a default anonymous user id per session.
    user_id = gr.Textbox(str(ObjectId()), label="userID")
    gr.ChatInterface(
        fn=respond,
        type="messages",
        additional_inputs=[user_id],
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)