Spaces:
Sleeping
Sleeping
File size: 2,764 Bytes
# ===================== IMPORTS =====================
import logging
from langgraph.graph import StateGraph, START, END
from langchain.messages import HumanMessage, AIMessage, SystemMessage
from src.Agents.checkpointer import get_checkpointer,_checkpointer
from src.Agents.nodes.question_generator_node import generate_questions
from src.Agents.nodes.chat_node import chat
from src.Agents.models.interview_model import ChatState
# Graph definition built over ChatState with a single "chat" node:
# START -> "chat" -> END.
builder = StateGraph(ChatState)
builder.add_node("chat", chat)
builder.add_edge(START, "chat")
builder.add_edge("chat", END)
# Module-level cache for the compiled graph; get_graph() fills it on first use.
_compiled_graph = None
async def get_graph():
    """Return the compiled interview graph, compiling it on first use.

    The first call awaits ``get_checkpointer()``, compiles ``builder`` with that
    checkpointer and caches the result in the module-level ``_compiled_graph``.
    Later calls return the cached object.

    As a debugging aid the first call also tries to write a Mermaid rendering
    of the graph to ``interview_graph.png`` in the current working directory.
    That export is best-effort: a failure is logged and never prevents the
    compiled graph from being returned.

    Returns:
        The compiled graph produced by ``builder.compile(checkpointer=...)``.
    """
    global _compiled_graph
    if _compiled_graph is None:
        cp = await get_checkpointer()
        compiled = builder.compile(checkpointer=cp)
        _compiled_graph = compiled
        try:
            # Render before opening the file so a failed render does not leave
            # an empty/truncated PNG behind.
            # NOTE(review): draw_mermaid_png() may depend on an external
            # renderer (network access) depending on the LangGraph draw
            # method in use - confirm for the deployment environment.
            png_bytes = compiled.get_graph().draw_mermaid_png()
            with open("interview_graph.png", "wb") as f:
                f.write(png_bytes)
        except Exception as exc:
            # The diagram is optional; the graph itself is already usable.
            logging.warning("Could not write interview_graph.png: %s", exc)
    return _compiled_graph
# ===================== Chat Interviewer FUNCTION =====================
async def chat_interviewer(thread_id: str, time_remain: int, topic: str, user_input: str):
    """Run the graph once for a single user message on the given thread.

    Args:
        thread_id: Identifier passed as ``configurable.thread_id`` in the
            invocation config.
        time_remain: Value stored under the ``time_remaining`` state key.
        topic: Value stored under the ``topic`` state key.
        user_input: Text wrapped in a ``HumanMessage`` and sent as the only
            entry of ``messages``.

    Returns:
        Whatever ``graph.ainvoke`` returns for this invocation.
    """
    compiled_graph = await get_graph()

    graph_input = {
        "messages": [HumanMessage(content=user_input)],
        "topic": topic,
        "time_remaining": time_remain,
    }
    run_config = {"configurable": {"thread_id": thread_id}}

    return await compiled_graph.ainvoke(graph_input, config=run_config)
async def load_conversation(thread_id: str):
    """Return the ``messages`` stored in the graph state for a thread.

    Args:
        thread_id: Identifier passed as ``configurable.thread_id`` when the
            state is fetched.

    Returns:
        The ``messages`` entry of the fetched state's values, or an empty
        list when that key is absent.
    """
    compiled_graph = await get_graph()
    thread_config = {'configurable': {'thread_id': thread_id}}
    snapshot = await compiled_graph.aget_state(config=thread_config)
    stored_values = snapshot.values
    return stored_values.get('messages', [])
async def deleteThread(thread_id: str):
    """Delete the checkpointed data of a thread, if the thread exists.

    Args:
        thread_id: Identifier of the thread to remove.

    Returns:
        ``True`` when the thread was found and deleted; ``False`` when no
        checkpoint exists for it or when any exception occurred (the
        exception is logged, not re-raised).
    """
    try:
        saver = await get_checkpointer()
        thread_config = {'configurable': {'thread_id': thread_id}}

        # Look the thread up before deleting so "missing" is reported
        # separately from "deleted".
        existing = await saver.aget_tuple(config=thread_config)
        if existing is not None:
            await saver.adelete_thread(thread_id=thread_id)
            logging.info(f"Thread {thread_id} deleted successfully.")
            return True

        logging.info(f"Thread {thread_id} not found, nothing to delete.")
        return False
    except Exception as exc:
        logging.error(f"Error deleting thread {thread_id}: {exc}")
        return False
async def close_checkpointer():
    """Close the checkpointer's database connection and clear the cached saver.

    The saver is looked up on the ``src.Agents.checkpointer`` module at call
    time. The name ``_checkpointer`` imported at the top of this file via
    ``from ... import _checkpointer`` is only a copy of the value at import
    time, so it does not reflect a saver assigned to that module afterwards,
    and resetting it would not clear the owning module's reference.

    Does nothing when no saver is currently set. Any exception raised while
    closing is logged and swallowed.
    """
    global _checkpointer
    # Function-local import: gives access to the live module attribute
    # without touching the file-level import block.
    import src.Agents.checkpointer as checkpointer_module

    saver = checkpointer_module._checkpointer
    if saver is None:
        return

    try:
        logging.info("Closing checkpointer database connection...")
        # AsyncSqliteSaver has a 'conn' attribute which is an aiosqlite connection
        conn = getattr(saver, 'conn', None)
        if conn:
            await conn.close()
        # Clear the owning module's reference; presumably get_checkpointer()
        # consults it to decide whether a new saver is needed - TODO confirm.
        checkpointer_module._checkpointer = None
        # Keep this module's imported alias in sync as well.
        _checkpointer = None
        # NOTE(review): a graph compiled earlier by get_graph() still holds
        # the closed saver; confirm this is only called at shutdown.
        logging.info("Checkpointer connection closed.")
    except Exception as e:
        logging.error(f"Error closing checkpointer: {e}")
|