imagineui / src /agents /workflow.py
Technologic101's picture
task: prototype RAG and design creation
c5273f3
raw
history blame
4.65 kB
from typing import Dict, List, Annotated, TypedDict
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
from langchain_core.messages import BaseMessage
from chains.design_rag import DesignRAG
from langchain.prompts import ChatPromptTemplate
import chainlit as cl
import json
# Shared state schema for the workflow graph, declared with the functional
# TypedDict syntax (equivalent to the class-based form).
AgentState = TypedDict(
    "AgentState",
    {
        # Conversation history; conversation_node inspects the last entry.
        "messages": List[BaseMessage],
        # HTML markup the generated stylesheet is written for.
        "html_content": str,
        # Requirements dict returned by extract_requirements on "!generate".
        "style_requirements": Dict,
        # Generated stylesheet text, written by generate_css_node (optional).
        "css_output": str | None,
    },
)
# Node functions
async def conversation_node(state: AgentState, rag: DesignRAG):
"""Handle conversation and requirement gathering"""
# Get last message
last_message = state["messages"][-1]
if not isinstance(last_message, HumanMessage):
return {"messages": state["messages"]}
# Check for style requirements readiness
if "!generate" in last_message.content.lower():
# Extract style requirements from conversation
requirements = await extract_requirements(state["messages"], rag)
return {
"messages": state["messages"],
"style_requirements": requirements,
"next": "generate_css"
}
# Normal conversation - get context and respond
response = await rag.query(last_message.content)
state["messages"].append(AIMessage(content=response))
return {"messages": state["messages"]}
def _parse_json_object(text: str) -> Dict:
    """Parse the JSON object embedded in an LLM reply.

    Chat models frequently wrap JSON in a Markdown fence (```json ... ```) or add
    a sentence before/after it, which makes a bare ``json.loads`` fail. This
    parses the span from the first ``{`` to the last ``}`` when one exists, and
    otherwise the whole text.

    Raises:
        json.JSONDecodeError: if the candidate text is not valid JSON.
        ValueError: if the JSON is valid but not an object.
    """
    start = text.find("{")
    end = text.rfind("}")
    candidate = text[start:end + 1] if 0 <= start < end else text
    parsed = json.loads(candidate)
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object of style requirements, got {type(parsed).__name__}"
        )
    return parsed


async def extract_requirements(
    messages: List[BaseMessage],
    rag: DesignRAG
) -> Dict:
    """Extract structured style requirements from the conversation history.

    Args:
        messages: Conversation so far; ``HumanMessage`` entries are labelled
            "User" and every other message "Assistant" in the prompt.
        rag: Design RAG helper whose ``llm`` performs the extraction.

    Returns:
        The requirements dict parsed from the model reply. The prompt requests
        the keys style_description, key_elements, color_scheme,
        layout_preferences and mood, but they are not validated here.

    Raises:
        ValueError: (including ``json.JSONDecodeError``) if the reply contains
            no parseable JSON object.
    """
    # Flatten the conversation into a labelled transcript for the prompt.
    context = "\n".join(
        f"{'User' if isinstance(m, HumanMessage) else 'Assistant'}: {m.content}"
        for m in messages
    )
    # Doubled braces are literal braces inside the f-string.
    prompt = f"""Based on this conversation, extract the key style requirements:
{context}
Provide the requirements in this JSON format:
{{
"style_description": "Brief description of desired style",
"key_elements": ["list", "of", "important", "visual", "elements"],
"color_scheme": "Description of colors",
"layout_preferences": "Any specific layout requirements",
"mood": "Desired emotional impact"
}}
"""
    response = await rag.llm.ainvoke(prompt)
    # Tolerate fenced / prose-wrapped JSON instead of failing on json.loads.
    return _parse_json_object(response.content)
def _strip_code_fence(text):
    """Return *text* without one surrounding Markdown code fence, if present.

    ````"```css\\nbody {}\\n```"```` becomes ``"body {}"``. Unfenced text, a
    single-line reply and non-string content are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    stripped = text.strip()
    if not stripped.startswith("```"):
        return text
    lines = stripped.splitlines()
    if len(lines) < 2:
        return text
    body = lines[1:]  # drop the opening fence together with its language tag
    if body[-1].strip() == "```":
        body = body[:-1]
    return "\n".join(body).strip()


async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
    """Generate a stylesheet for ``state["html_content"]`` from the extracted requirements.

    Looks up similar designs with ``rag.query_similar_designs``, asks
    ``rag.llm`` for a complete CSS file, and reports the result in the chat.

    Args:
        state: Graph state; reads ``style_requirements``, ``html_content`` and
            ``messages``.
        rag: Design RAG helper providing ``query_similar_designs`` and ``llm``.

    Returns:
        A new state dict (the input is not mutated) with ``css_output`` set to
        the generated CSS and a completion ``AIMessage`` appended to
        ``messages``.

    Raises:
        KeyError: if ``style_requirements``, ``html_content`` or ``messages`` is
            missing from ``state``.
    """
    requirements = state["style_requirements"]

    # Retrieve reference designs that resemble the requested style.
    similar_designs = await rag.query_similar_designs(requirements)

    prompt = ChatPromptTemplate.from_template("""You are an expert CSS designer creating a style for the following HTML structure:
HTML Structure:
{html_content}
Style Requirements:
{requirements}
Similar Design Examples:
{examples}
Generate a complete CSS file that:
1. Implements the requested style requirements
2. Uses modern CSS features appropriately
3. Creates a cohesive and polished design
4. Includes comments explaining key style decisions
Respond only with the CSS code, starting with a comment block describing the design approach.
""")

    # format_messages yields a proper chat message; .format() would flatten
    # the template into a single "Human: ..." string.
    response = await rag.llm.ainvoke(
        prompt.format_messages(
            html_content=state["html_content"],
            requirements=json.dumps(requirements, indent=2),
            examples=similar_designs,
        )
    )

    # Models often wrap code in a ```css fence despite the instruction above;
    # strip it so css_output is plain CSS and the reply is not double-fenced.
    css = _strip_code_fence(response.content)

    reply = AIMessage(content="""I've generated the CSS based on your requirements.
Here's what I created:
```css
{css}
```
Would you like me to explain any part of the design or make any adjustments?
""".format(css=css))

    return {**state, "css_output": css, "messages": [*state["messages"], reply]}
def create_graph(rag: DesignRAG) -> StateGraph:
    """Build and compile the design-assistant workflow graph.

    Each invocation handles a single user turn. ``conversation`` runs first.
    If the latest message is a human ``!generate`` command, control passes to
    ``generate_css``. Otherwise the graph ends, so the caller can collect the
    next user message and invoke the graph again.

    Args:
        rag: Design RAG helper bound into both nodes.

    Returns:
        The compiled graph produced by ``workflow.compile()``. The ``StateGraph``
        annotation is kept for compatibility; the value is the compiled form.
    """
    workflow = StateGraph(AgentState)

    # Real ``async def`` wrappers: a plain ``lambda`` returning a coroutine is
    # not a coroutine function, so LangGraph would run it as a sync node and
    # receive an un-awaited coroutine instead of a state update.
    async def _conversation(state: AgentState):
        return await conversation_node(state, rag)

    async def _generate_css(state: AgentState):
        return await generate_css_node(state, rag)

    def _route_after_conversation(state: AgentState) -> str:
        """Go to CSS generation only when the last message is a human "!generate"."""
        messages = state.get("messages") or []
        if messages and isinstance(messages[-1], HumanMessage):
            content = messages[-1].content
            if isinstance(content, str) and "!generate" in content.lower():
                return "generate_css"
        # A normal turn ends with the assistant's reply appended; stop and wait
        # for the next user message rather than looping back into the node.
        return END

    workflow.add_node("conversation", _conversation)
    workflow.add_node("generate_css", _generate_css)

    workflow.set_entry_point("conversation")
    # One conditional edge replaces the two unconditional edges
    # (conversation -> conversation, conversation -> generate_css). Those fired
    # on every step, looping until the recursion limit and running generate_css
    # before any requirements existed.
    workflow.add_conditional_edges(
        "conversation",
        _route_after_conversation,
        {"generate_css": "generate_css", END: END},
    )
    workflow.add_edge("generate_css", END)

    return workflow.compile()