# Stock_Agent_optimized / utils/langgraph_conversation.py
# Author: cryogenic22
# Commit b24c73d (verified): Create utils/langgraph_conversation.py
# utils/langgraph_conversation.py
from typing import TypedDict

from langgraph.graph import StateGraph, END
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage, AIMessage
import streamlit as st
class ConversationState(TypedDict, total=False):
    """State dictionary shared by the nodes of ``ConversationalLearningGraph``.

    Every key is optional (``total=False``) because each node fills in only
    the keys it produces; ``question`` is supplied by ``process_question``.
    """

    question: str  # the user's question, as passed to process_question
    question_analysis: str  # raw LLM analysis text from _understand_question
    category: str  # one of ConversationalLearningGraph.CATEGORIES
    needs_prerequisites: bool  # verdict from _evaluate_prerequisites
    prerequisite_checks: int  # how many times check_prerequisites has run
    response: str  # educational answer from _generate_response
    next_topics: str  # follow-up suggestions from _suggest_next_topics


class ConversationalLearningGraph:
    """LangGraph pipeline that answers trading questions and suggests follow-ups.

    Node order: ``understand_question -> check_prerequisites ->
    generate_response -> suggest_next_topics``. ``check_prerequisites`` may
    route back to ``understand_question`` at most ``MAX_PREREQUISITE_LOOPS``
    times. Completed question/answer turns are kept in a
    ``ConversationBufferMemory``.
    """

    # Categories the LLM's question analysis is mapped onto.
    CATEGORIES = (
        "basic_concepts",
        "technical_analysis",
        "risk_management",
        "trading_strategy",
        "market_mechanics",
    )
    # Fallback when the analysis text names none of CATEGORIES.
    DEFAULT_CATEGORY = "basic_concepts"
    # Upper bound on check_prerequisites -> understand_question loop-backs.
    # Looping back does not change the state, so without a bound a persistent
    # "needs prerequisites" verdict would recurse until LangGraph's recursion
    # limit aborts the run.
    MAX_PREREQUISITE_LOOPS = 1

    def __init__(self, anthropic_api_key):
        """Create the LLM client, the chat memory and the compiled graph.

        Args:
            anthropic_api_key: API key passed to ``ChatAnthropic``.
        """
        self.llm = ChatAnthropic(anthropic_api_key=anthropic_api_key)
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )
        self.graph = self._create_graph()

    def _create_graph(self):
        """Build and compile the conversation workflow.

        Returns:
            The compiled LangGraph graph.
        """
        workflow = StateGraph(ConversationState)

        # One node per conversation stage.
        workflow.add_node("understand_question", self._understand_question)
        workflow.add_node("check_prerequisites", self._check_prerequisites)
        workflow.add_node("generate_response", self._generate_response)
        workflow.add_node("suggest_next_topics", self._suggest_next_topics)

        # A graph without an entry point cannot be compiled.
        workflow.set_entry_point("understand_question")

        workflow.add_edge("understand_question", "check_prerequisites")
        # check_prerequisites routes only through this conditional edge. An
        # additional static edge to generate_response would fire alongside
        # it and fork the run.
        workflow.add_conditional_edges(
            "check_prerequisites",
            self._needs_prerequisites,
            {
                True: "understand_question",  # loop back if prerequisites needed
                False: "generate_response",  # continue if prerequisites met
            },
        )
        workflow.add_edge("generate_response", "suggest_next_topics")
        workflow.add_edge("suggest_next_topics", END)
        return workflow.compile()

    async def _understand_question(self, state):
        """Ask the LLM to analyse the question and derive its category.

        Returns:
            State update with ``question_analysis`` and ``category``.
        """
        prompt = ChatPromptTemplate.from_messages([
            (
                "system",
                "You are an expert at understanding trading questions. "
                "State which one of these categories the question best "
                "fits: {categories}.",
            ),
            ("human", "Analyze this trading question: {question}"),
        ])
        response = await self.llm.ainvoke(
            prompt.format_messages(
                question=state["question"],
                categories=", ".join(self.CATEGORIES),
            )
        )
        return {
            "question_analysis": response.content,
            "category": self._categorize_question(response.content),
        }

    def _check_prerequisites(self, state):
        """Decide whether prerequisite knowledge is needed and count the check.

        Returns:
            State update with ``needs_prerequisites`` and the incremented
            ``prerequisite_checks`` counter, which ``_needs_prerequisites``
            uses to bound the loop.
        """
        history = self.memory.chat_memory.messages
        return {
            "needs_prerequisites": self._evaluate_prerequisites(
                state["category"],
                history,
            ),
            "prerequisite_checks": state.get("prerequisite_checks", 0) + 1,
        }

    async def _generate_response(self, state):
        """Generate the educational answer, using prior turns as context.

        Returns:
            State update with ``response``.
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an expert trading educator."),
            (
                "human",
                "Given this trading question and context:\n"
                "Question: {question}\n"
                "Category: {category}\n"
                "Previous discussion: {history}\n"
                "Provide a detailed, educational response.",
            ),
        ])
        response = await self.llm.ainvoke(
            prompt.format_messages(
                question=state["question"],
                category=state["category"],
                # Render messages as readable lines instead of a list repr.
                history=self._format_history(self.memory.chat_memory.messages),
            )
        )
        return {"response": response.content}

    async def _suggest_next_topics(self, state):
        """Ask the LLM for three related topics to explore next.

        Returns:
            State update with ``next_topics`` (the LLM's raw text).
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", "Suggest related trading topics to explore next."),
            (
                "human",
                "Based on:\n"
                "Current topic: {question}\n"
                "Response given: {response}\n"
                "Suggest 3 related topics to explore next.",
            ),
        ])
        suggestions = await self.llm.ainvoke(
            prompt.format_messages(
                question=state["question"],
                response=state["response"],
            )
        )
        return {"next_topics": suggestions.content}

    def _needs_prerequisites(self, state):
        """Routing predicate for the conditional edge out of check_prerequisites.

        Returns:
            True to loop back to ``understand_question``, but only while the
            number of checks so far is within ``MAX_PREREQUISITE_LOOPS``.
            Otherwise False.
        """
        if not state.get("needs_prerequisites", False):
            return False
        return state.get("prerequisite_checks", 0) <= self.MAX_PREREQUISITE_LOOPS

    def _categorize_question(self, analysis):
        """Map free-text analysis onto one of ``CATEGORIES``.

        The category mentioned earliest in *analysis* wins. Matching is
        case-insensitive and accepts either the ``snake_case`` name or its
        space-separated form.

        Returns:
            The matched category, or ``DEFAULT_CATEGORY`` if none is found.
        """
        text = (analysis or "").lower()
        best, best_pos = self.DEFAULT_CATEGORY, None
        for category in self.CATEGORIES:
            for variant in (category, category.replace("_", " ")):
                pos = text.find(variant)
                if pos != -1 and (best_pos is None or pos < best_pos):
                    best, best_pos = category, pos
        return best

    def _evaluate_prerequisites(self, category, history):
        """Evaluate whether the user needs prerequisite knowledge.

        Placeholder: always returns False. Real logic should inspect
        *category* and the prior chat *history*.
        """
        return False

    @staticmethod
    def _format_history(messages):
        """Render chat messages as ``'Human: ...'`` / ``'AI: ...'`` lines.

        Messages are labelled by their ``type`` attribute ("human" or "ai").
        Any other type is capitalised. Returns ``"(none)"`` when there are
        no messages, so the prompt never gets an empty slot.
        """
        speakers = {"human": "Human", "ai": "AI"}
        lines = []
        for message in messages or []:
            role = getattr(message, "type", "") or ""
            speaker = speakers.get(role, role.capitalize() or "Message")
            lines.append(f"{speaker}: {message.content}")
        return "\n".join(lines) if lines else "(none)"

    async def process_question(self, question):
        """Run *question* through the graph and record the exchange in memory.

        Args:
            question: The user's trading question.

        Returns:
            dict with ``'response'`` (answer text) and ``'next_topics'``
            (suggested follow-up topics text).

        Both the user and AI messages are appended to memory only after the
        graph completes. Prompts built during this run therefore see only
        earlier turns as "Previous discussion", and a failed run leaves
        memory untouched.
        """
        # Compiled LangGraph graphs expose ainvoke(); there is no arun().
        final_state = await self.graph.ainvoke({"question": question})
        self.memory.chat_memory.add_user_message(question)
        self.memory.chat_memory.add_ai_message(final_state["response"])
        return {
            "response": final_state["response"],
            "next_topics": final_state["next_topics"],
        }