"""MitraVerse: a Streamlit chat UI over a LangGraph plan -> clarify -> code -> explain pipeline.

Graph flow:
    LLM_Agent -> Generate_Questions -> [Wait_For_Answers -> Generate_Questions]*
    -> Handle_Answers -> Generate_Code -> Code_Explainer -> END
"""

import os
import uuid
from typing import Annotated, Any, Dict, List

import streamlit as st
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_together import Together
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict

load_dotenv()

groq_api_key = os.getenv("GROQ_API_KEY")
together_api_key = os.getenv("TOGETHER_API_KEY")
google_api_key = os.getenv("GOOGLE_API_KEY")


# -------------------
# Define State
# -------------------
class State(TypedDict):
    messages: List[str]  # confirmations appended by handle_answers
    answers: List[str]  # user answers; the nodes read answers[0]
    retry_count: int  # number of Wait_For_Answers passes so far
    questions: str  # raw text produced by generate_questions
    code: str  # generated code (from generate_code)
    explanation: str  # explanation of `code` (from explain_code)
    task_plan: str  # plan text (from llm_agent)
    user_input: str  # the original chat message


# -------------------
# Initialize LLMs
# -------------------
question_model = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash", temperature=0.7, google_api_key=google_api_key
)
# NOTE: `Together` is a completion LLM: `.invoke` returns a plain str, not a
# message object. `_text` below normalizes both shapes.
llm_agent_model = Together(model="mistralai/Mistral-7B-Instruct-v0.1", api_key=together_api_key)
code_model = Together(model="deepseek-ai/deepseek-coder-6.7b-instruct", api_key=together_api_key)
confirm_model = ChatGroq(model="qwen/qwen3-32b", api_key=groq_api_key)
# Previously "meta-llama/llama-guard-4-12b", which is a content-safety
# classifier (it answers "safe"/"unsafe") and cannot explain code.
# NOTE(review): confirm this model ID is enabled on your Groq account.
explain_model = ChatGroq(model="llama-3.3-70b-versatile", api_key=groq_api_key)


# -------------------
# Helpers
# -------------------
def _text(response: Any) -> str:
    """Return the text of a model response.

    Chat models return a message with ``.content``; the ``Together`` LLM
    returns a bare ``str``. Anything else is stringified.
    """
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)


def _latest_answer(state: State) -> str:
    """Return ``answers[0]``, or fall back to ``user_input`` when there are no answers."""
    answers = state.get("answers") or []
    return answers[0] if answers else state.get("user_input", "")


# -------------------
# Define Node Functions
# -------------------
def llm_agent(state: State) -> State:
    """Ask the planner model to break the user's request into a task plan."""
    messages = [
        SystemMessage(content="You are an AI task planner. Break down user instructions."),
        HumanMessage(content=state["user_input"]),
    ]
    state["task_plan"] = _text(llm_agent_model.invoke(messages))
    return state


def generate_questions(state: State) -> State:
    """Generate clarifying follow-up questions for the latest answer."""
    messages = [
        SystemMessage(content="You generate follow-up questions to clarify vague instructions."),
        HumanMessage(content=_latest_answer(state)),
    ]
    state["questions"] = _text(question_model.invoke(messages))
    return state


def generate_code(state: State) -> State:
    """Generate Python code for the latest answer."""
    messages = [
        SystemMessage(content="You are a coding expert. Generate clean, well-documented Python code."),
        HumanMessage(content=_latest_answer(state)),
    ]
    state["code"] = _text(code_model.invoke(messages))
    return state


def handle_answers(state: State) -> State:
    """Have the confirm model acknowledge the answer; append its reply to ``messages``."""
    print("Handling answers...")
    answer = _latest_answer(state)
    system_prompt = "You are a helpful assistant that confirms the received idea."
    user_msg = f"The user said: '{answer}'. Confirm and move ahead."
    response = confirm_model.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_msg),
    ])
    # Build a new list instead of mutating whatever list the caller passed in.
    state["messages"] = [*(state.get("messages") or []), _text(response).strip()]
    return state


def explain_code(state: State) -> State:
    """Ask the explain model to describe ``state["code"]`` in simple terms."""
    print("Explaining code...")
    code = state["code"]
    system_prompt = "You are a Python tutor. Explain what the following code does in simple terms."
    user_msg = f"Code:\n{code}"
    response = explain_model.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_msg),
    ])
    state["explanation"] = _text(response).strip()
    return state


def wait_for_answers(state: State) -> State:
    """Demo stub: count a retry and inject a canned answer from the 2nd retry on."""
    print("Waiting for answers...")
    state["retry_count"] = state.get("retry_count", 0) + 1
    # Simulate receiving an answer after 2 retries
    if state["retry_count"] >= 2:
        state["answers"] = ["Build a calculator app"]
    return state


# -------------------
# Define Condition Function
# -------------------
MAX_RETRIES = 3


def check_if_answered(state: State) -> str:
    """Route to "answered" when answers exist or the retry budget is spent."""
    if state.get("answers"):
        return "answered"
    if state.get("retry_count", 0) >= MAX_RETRIES:
        print("Max retries reached. Proceeding anyway.")
        return "answered"
    return "not_answered"


# -------------------
# Build the Graph
# -------------------
builder = StateGraph(State)
builder.add_node("LLM_Agent", llm_agent)
builder.add_node("Generate_Questions", generate_questions)
builder.add_node("Wait_For_Answers", wait_for_answers)
builder.add_node("Handle_Answers", handle_answers)
builder.add_node("Generate_Code", generate_code)
builder.add_node("Code_Explainer", explain_code)

builder.set_entry_point("LLM_Agent")
builder.add_edge("LLM_Agent", "Generate_Questions")
builder.add_conditional_edges(
    "Generate_Questions",
    check_if_answered,
    {
        "answered": "Handle_Answers",
        "not_answered": "Wait_For_Answers",
    },
)
builder.add_edge("Wait_For_Answers", "Generate_Questions")
builder.add_edge("Handle_Answers", "Generate_Code")
builder.add_edge("Generate_Code", "Code_Explainer")
builder.add_edge("Code_Explainer", END)


# -------------------
# Streamlit UI
# -------------------
st.set_page_config(page_title="MitraVerse", layout="wide")


@st.cache_resource
def _compile_graph():
    """Compile the graph once per server process.

    Caching keeps one MemorySaver alive across Streamlit reruns. Previously a
    fresh, empty checkpointer was created on every interaction, so the
    per-session thread_id never had persisted history.
    """
    return builder.compile(checkpointer=MemorySaver())


graph = _compile_graph()

if "thread_id" not in st.session_state:
    st.session_state.thread_id = str(uuid.uuid4())
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# For each node, the state field it produces and how to format it for the chat.
# Every node returns the full state, so showing every non-empty field after
# every node would repeat earlier outputs. Instead, show only the field the
# node that just ran produced.
_NODE_OUTPUTS = {
    "LLM_Agent": ("task_plan", lambda v: f"🔧 Task Plan:\n{v}"),
    "Generate_Questions": ("questions", lambda v: f"❓ Questions:\n{v}"),
    "Generate_Code": ("code", lambda v: f"💻 Code:\n```python\n{v}\n```"),
    "Code_Explainer": ("explanation", lambda v: f"📘 Explanation:\n{v}"),
}


def _render(msg) -> None:
    """Render one chat message.

    Uses Streamlit's chat container and plain markdown, so code fences render
    and user/LLM text is not injected as raw HTML.
    """
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    with st.chat_message(role):
        st.markdown(msg.content)


st.title("🧠MitraVerse")

# Show chat
for msg in st.session_state.chat_history:
    _render(msg)

user_input = st.chat_input("What would you like to build?")

if user_input:
    human_msg = HumanMessage(content=user_input)
    st.session_state.chat_history.append(human_msg)
    _render(human_msg)

    state_input: State = {
        "messages": [],
        "answers": [user_input],
        "retry_count": 0,
        "code": "",
        "explanation": "",
        "questions": "",
        "task_plan": "",
        "user_input": user_input,
    }
    # A checkpointer needs the thread id under config["configurable"].
    config = {"configurable": {"thread_id": st.session_state.thread_id}}

    for step in graph.stream(state_input, config=config, stream_mode="updates"):
        for node_name, update in step.items():
            spec = _NODE_OUTPUTS.get(node_name)
            if spec is None or not isinstance(update, dict):
                continue
            field, fmt = spec
            value = update.get(field)
            if value:
                bot_msg = SystemMessage(content=fmt(value))
                st.session_state.chat_history.append(bot_msg)
                # Render right away; the history loop above has already run.
                _render(bot_msg)