# tritan-api / api/planner.py
# Uploaded by Madras1 ("Upload 17 files"), commit d6815ad (verified)
"""
Auto-Planner API - Generates workflows from natural language descriptions
"""
import os
import json
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional
from dotenv import load_dotenv
# Populate os.environ via python-dotenv before any setting below is read.
load_dotenv()

# OpenAI-compatible chat-completions endpoint used by generate_workflow.
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
# Read once at import time; generate_workflow rejects requests while it is unset.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Router that the planner endpoint below is registered on.
router = APIRouter()
# System prompt sent as the first chat message on every planning request.
# It lists the node types the planner may emit and pins the JSON shape
# ({"name", "nodes", "connections"}) that generate_workflow parses and returns.
SYSTEM_PROMPT = """You are a workflow planner for Tritan, an AI-native workflow engine.
Given a task description, create a workflow using these node types:
- trigger: Starting point of the workflow (ALWAYS required as first node)
- action: Execute a generic task
- condition: Branch based on a condition (has 'condition' field like "value > 10")
- loop: Iterate over items
- llm: AI language model call (has 'prompt', 'provider', 'model' fields)
- http: HTTP API request (has 'url', 'method' fields)
- code: Custom JavaScript code execution (has 'code' field)
- transform: Data transformation
Return ONLY valid JSON in this exact format, no markdown, no explanation:
{
"name": "Descriptive Workflow Name",
"nodes": [
{"id": "n1", "type": "trigger", "x": 100, "y": 150, "data": {}},
{"id": "n2", "type": "llm", "x": 350, "y": 150, "data": {"prompt": "Your prompt here", "provider": "groq", "model": "llama-3.3-70b-versatile"}},
{"id": "n3", "type": "http", "x": 600, "y": 150, "data": {"url": "https://api.example.com", "method": "POST"}}
],
"connections": [
{"from": "n1", "to": "n2"},
{"from": "n2", "to": "n3"}
]
}
Rules:
- Position nodes horizontally: x increases by 250 for each step
- Keep y around 100-200 for a clean layout
- Always start with a trigger node
- Make prompts specific and actionable
- Use realistic URLs for HTTP nodes
- For conditions, write clear boolean expressions"""
class PlannerRequest(BaseModel):
    """Request body for the /generate route on this router."""

    # Natural-language description of the workflow to build; the endpoint
    # rejects it with 400 when it is blank after stripping whitespace.
    task: str
    # NOTE(review): accepted but never read by generate_workflow in this file,
    # which always calls the Groq endpoint — confirm whether other providers
    # are meant to be supported.
    provider: Optional[str] = "groq"
    # Model name placed in the "model" field of the chat-completions request.
    model: Optional[str] = "llama-3.3-70b-versatile"
class PlannerResponse(BaseModel):
    """Workflow returned by the /generate route."""

    # Human-readable workflow name.
    name: str
    # Node dicts; per SYSTEM_PROMPT's format each is expected to carry
    # "id", "type", "x", "y" and "data" (not enforced by this model).
    nodes: list
    # Edge dicts; per SYSTEM_PROMPT's format presumably {"from": id, "to": id}
    # — not enforced by this model.
    connections: list
_DEFAULT_MODEL = "llama-3.3-70b-versatile"
_DEFAULT_WORKFLOW_NAME = "Generated Workflow"


def _upstream_error_message(response: httpx.Response) -> str:
    """Best-effort extraction of the error message from a failed Groq response.

    The expected error body is {"error": {"message": ...}}, but a proxy or
    outage page may return a non-JSON body or a different shape. In those
    cases fall back to a generic message instead of crashing while building
    the error.
    """
    try:
        payload = response.json()
    except ValueError:  # not JSON (json.JSONDecodeError subclasses ValueError)
        return "API error"
    error = payload.get("error") if isinstance(payload, dict) else None
    if isinstance(error, dict):
        return error.get("message", "API error")
    if isinstance(error, str) and error:
        return error
    return "API error"


async def _request_completion(model: str, task: str) -> str:
    """Ask Groq to plan *task* with *model* and return the assistant's reply text.

    Raises:
        HTTPException: 504 on timeout; 500 on transport errors or a malformed
            success payload; Groq's own status code when it rejects the request.
    """
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                GROQ_API_URL,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {GROQ_API_KEY}",
                },
                json={
                    "model": model,
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": f"Create a workflow for this task: {task}"},
                    ],
                    "temperature": 0.3,
                    "max_tokens": 2000,
                },
            )
    except httpx.TimeoutException as exc:  # subclass of RequestError: must come first
        raise HTTPException(status_code=504, detail="Request to LLM timed out") from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=500, detail=f"Request error: {str(exc)}") from exc

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=_upstream_error_message(response),
        )

    try:
        content = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        raise HTTPException(
            status_code=500, detail="Unexpected response format from LLM API"
        ) from exc
    if not isinstance(content, str):
        # e.g. a null "content" field: there is no text to parse.
        raise HTTPException(status_code=500, detail="LLM response contained no text content")
    return content


def _extract_json_text(content: str) -> str:
    """Return the JSON object text embedded in an LLM reply.

    The model is told to return bare JSON, but replies sometimes wrap it in a
    markdown code fence (with any language tag) or add a sentence around it.
    Slicing from the first "{" to the last "}" copes with all of these. If no
    such span exists, the stripped reply is returned unchanged so json.loads
    reports a meaningful error.
    """
    text = content.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        return text[start : end + 1]
    return text


def _normalize_workflow(workflow: object) -> dict:
    """Validate parsed planner output and fill in fields the model omitted.

    Returns a dict with "name", "nodes" and "connections" keys ready for
    PlannerResponse. Node dicts are updated in place with a fallback "id",
    an empty "data", and a left-to-right layout position when those keys are
    missing or null.

    Raises:
        HTTPException: 500 when the output is not a JSON object, "nodes" is
            missing or not a list of objects, or "connections" is not a list.
    """
    if not isinstance(workflow, dict):
        raise HTTPException(status_code=500, detail="Generated workflow is not a JSON object")

    nodes = workflow.get("nodes")
    if nodes is None:
        raise HTTPException(status_code=500, detail="Generated workflow missing 'nodes'")
    if not isinstance(nodes, list) or not all(isinstance(node, dict) for node in nodes):
        raise HTTPException(
            status_code=500, detail="Generated workflow 'nodes' must be a list of objects"
        )

    connections = workflow.get("connections")
    if connections is None:
        connections = []
    elif not isinstance(connections, list):
        raise HTTPException(
            status_code=500, detail="Generated workflow 'connections' must be a list"
        )

    name = workflow.get("name")
    if not isinstance(name, str) or not name.strip():
        name = _DEFAULT_WORKFLOW_NAME

    for i, node in enumerate(nodes):
        # Fallback layout follows the prompt's rules: x steps by 250, y ~150.
        # A fresh dict per node so no two nodes share the same "data" object.
        defaults = {"id": f"node_{i}", "data": {}, "x": 100 + i * 250, "y": 150}
        for key, value in defaults.items():
            if node.get(key) is None:
                node[key] = value

    return {"name": name, "nodes": nodes, "connections": connections}


@router.post("/generate", response_model=PlannerResponse)
async def generate_workflow(request: PlannerRequest):
    """Generate a workflow from a natural language task description.

    Sends the task to Groq's chat-completions API together with SYSTEM_PROMPT,
    parses the JSON workflow out of the reply, and fills in defaults for a
    missing name, connections, or per-node fields.

    Raises:
        HTTPException: 400 for a blank task; 500 when GROQ_API_KEY is unset or
            the reply cannot be turned into a workflow; 504 on LLM timeout;
            Groq's own status code when it rejects the request.
    """
    if not request.task.strip():
        raise HTTPException(status_code=400, detail="Task description is required")
    if not GROQ_API_KEY:
        raise HTTPException(status_code=500, detail="GROQ_API_KEY not configured in backend")

    # NOTE(review): request.provider is accepted but ignored; every request goes
    # to Groq. Fall back to the default model so an explicit null is not sent.
    model = request.model or _DEFAULT_MODEL
    content = await _request_completion(model, request.task)

    try:
        workflow = json.loads(_extract_json_text(content))
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=500, detail=f"Failed to parse workflow JSON: {str(exc)}"
        ) from exc

    return PlannerResponse(**_normalize_workflow(workflow))