Spaces:
Sleeping
Sleeping
File size: 4,113 Bytes
"""
Agent API — FastAPI endpoints for demo and dashboard interaction.

Endpoints:
    POST /run-episode     - Run one episode and return metrics
    POST /run-comparison  - Run baseline vs memory comparison
    GET  /metrics         - Get training/episode history
    GET  /memory/stats    - Memory store statistics
    POST /memory/search   - Search memory for lessons
    GET  /memory/all      - List stored experiences (up to 200)
    POST /memory/clear    - Clear memory store
    GET  /health          - Health check
"""
import json
import logging
import os
import sys
import threading
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Make the directory two levels above this file (presumably the project root)
# importable, so top-level packages such as `memory` and `agent` resolve when this
# module is run directly. Must run before the local import below.
sys.path.append(str(Path(__file__).resolve().parent.parent))
from memory.memory_store import MemoryStore  # noqa: E402  (needs the sys.path tweak above)

app = FastAPI(title="ToolMind Agent API", version="1.0.0")
# Wide-open CORS: any origin, method and header may call this API.
# NOTE(review): acceptable for a local demo/dashboard; restrict allow_origins
# before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# All persisted state lives in a `data` directory next to this file's parent package.
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
# JSON list of per-episode result dicts; /run-episode appends to it, /metrics reads it.
METRICS_FILE = DATA_DIR / "training_log.json"
# Process-wide memory store backing /health and the /memory/* endpoints. The agent
# endpoints build their own CombinedAgent pointed at the same chroma_data directory.
memory_store = MemoryStore(persist_dir=str(DATA_DIR / "chroma_data"))
class EpisodeRequest(BaseModel):
    # Request body for POST /run-episode. Every field has a default, so an empty
    # JSON object is a valid request.
    # Forwarded to CombinedAgent.run_episode(task_type=...); not validated here.
    # NOTE(review): the accepted values are defined by the agent — confirm.
    task_type: str = "hard"
    # Forwarded to the CombinedAgent(use_memory=...) constructor.
    use_memory: bool = True
    # Forwarded to CombinedAgent.run_episode(episode_num=...).
    episode_num: int = 0
class ComparisonRequest(BaseModel):
    # Request body for POST /run-comparison. Every field has a default.
    # Forwarded to CombinedAgent.run_comparison(task_type=...).
    task_type: str = "hard"
    # Forwarded to CombinedAgent.run_comparison(num_episodes=...).
    # NOTE(review): no lower or upper bound is enforced, so a single request can ask
    # for an arbitrarily long (and possibly expensive) comparison run.
    num_episodes: int = 3
class MemorySearchRequest(BaseModel):
    # Request body for POST /memory/search.
    # Free-text query passed to MemoryStore.retrieve_lessons and
    # MemoryStore.format_lessons_for_prompt.
    query: str
    # Passed as n_results to both store calls (presumably the maximum number of
    # lessons to return). Not range-checked here.
    # NOTE(review): confirm how MemoryStore handles values <= 0.
    n_results: int = 3
def _load_metrics(path: Optional[Path] = None) -> list[dict]:
    """Load the persisted episode-metrics log.

    Args:
        path: JSON file to read. Defaults to ``METRICS_FILE``; accepts any
            path-like value so the loader can be pointed at another log file.

    Returns:
        The decoded JSON content (a list of per-episode result dicts), or an
        empty list when the file does not exist yet.

    Raises:
        json.JSONDecodeError: if the file exists but does not contain valid JSON.
    """
    source = METRICS_FILE if path is None else Path(path)
    # EAFP instead of exists()-then-open(): no window in which the file can
    # change state between the check and the read.
    try:
        with open(source, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
def _save_metrics(metrics: list[dict], path: Optional[Path] = None) -> None:
    """Persist the episode-metrics log as indented JSON, atomically.

    The data is first written to a temporary file beside the target and then
    moved into place with ``os.replace``. Readers of the log therefore never see
    a half-written file, and a crash mid-write cannot leave truncated JSON that
    would break every later load.

    Args:
        metrics: JSON-serialisable list of per-episode result dicts.
        path: destination file. Defaults to ``METRICS_FILE``.

    Raises:
        TypeError: if ``metrics`` is not JSON-serialisable; the existing file is
            left untouched in that case.
        OSError: if the directory or file cannot be written.
    """
    target = METRICS_FILE if path is None else Path(path)
    # The data directory may not exist yet on a fresh checkout.
    target.parent.mkdir(parents=True, exist_ok=True)
    # Temp name unique per process and thread, so concurrent saves never share
    # (and interleave writes into) the same temporary file.
    tmp = target.with_name(f"{target.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        os.replace(tmp, target)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass  # cleanup is best-effort; the original error is what matters
        raise
@app.get("/health")
def health():
    """Liveness probe that also reports how much state the service holds.

    Returns a dict with a constant ``status`` plus the number of entries in the
    shared memory store and in the persisted metrics log.
    """
    memory_count = memory_store.count()
    logged_count = len(_load_metrics())
    return dict(
        status="ok",
        memory_entries=memory_count,
        metrics_entries=logged_count,
    )
@app.post("/run-episode")
def run_episode(req: EpisodeRequest):
    """Run a single agent episode, append its result to the metrics log, and return it.

    Raises:
        HTTPException: status 500 with ``str(error)`` as detail if importing or
            constructing the agent, running the episode, or saving the metrics
            fails. The full traceback is written to this module's logger.
    """
    try:
        # Imported lazily, presumably so the rest of the API can start even when
        # the agent package is slow or unavailable to import.
        from agent.combined_agent import CombinedAgent

        agent = CombinedAgent(
            use_memory=req.use_memory,
            memory_dir=str(DATA_DIR / "chroma_data"),
        )
        result = agent.run_episode(
            task_type=req.task_type,
            episode_num=req.episode_num,
            verbose=False,
        )
        # Read-modify-write of the whole log: append this episode's result.
        metrics = _load_metrics()
        metrics.append(result)
        _save_metrics(metrics)
        return result
    except Exception as e:
        # Only str(e) reaches the client, so keep the traceback in the server log
        # and chain the original exception for debugging.
        logging.getLogger(__name__).exception(
            "run-episode failed (task_type=%s, episode_num=%s)",
            req.task_type,
            req.episode_num,
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/run-comparison")
def run_comparison(req: ComparisonRequest):
    """Run the agent's baseline-vs-memory comparison and return its results.

    The agent is always constructed with ``use_memory=True`` and the shared
    ``chroma_data`` directory. Unlike /run-episode, nothing is appended to the
    metrics log.

    Raises:
        HTTPException: status 500 with ``str(error)`` as detail if importing or
            constructing the agent or running the comparison fails. The full
            traceback is written to this module's logger.
    """
    try:
        # Imported lazily, presumably so the rest of the API can start even when
        # the agent package is slow or unavailable to import.
        from agent.combined_agent import CombinedAgent

        agent = CombinedAgent(
            use_memory=True,
            memory_dir=str(DATA_DIR / "chroma_data"),
        )
        results = agent.run_comparison(
            task_type=req.task_type,
            num_episodes=req.num_episodes,
            verbose=False,
        )
        return results
    except Exception as e:
        # Only str(e) reaches the client, so keep the traceback in the server log
        # and chain the original exception for debugging.
        logging.getLogger(__name__).exception(
            "run-comparison failed (task_type=%s, num_episodes=%s)",
            req.task_type,
            req.num_episodes,
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/metrics")
def get_metrics():
    """Return the full persisted list of episode-result dicts."""
    history = _load_metrics()
    return history
@app.get("/memory/stats")
def memory_stats():
    """Return the statistics reported by the shared memory store, unchanged."""
    stats = memory_store.get_stats()
    return stats
@app.post("/memory/search")
def memory_search(req: MemorySearchRequest):
    """Look up lessons relevant to a query.

    Returns both the raw lessons from the memory store and the same query
    rendered by the store as prompt-ready text.
    """
    query, limit = req.query, req.n_results
    found = memory_store.retrieve_lessons(query, n_results=limit)
    prompt_text = memory_store.format_lessons_for_prompt(query, n_results=limit)
    return dict(lessons=found, formatted_prompt=prompt_text)
@app.get("/memory/all")
def memory_all():
    """Return the stored experiences, capped at 200 entries."""
    max_items = 200
    experiences = memory_store.get_all_experiences(limit=max_items)
    return experiences
@app.post("/memory/clear")
def memory_clear():
    """Wipe the shared memory store and acknowledge with a zero count."""
    memory_store.clear()
    response = {"status": "cleared"}
    response["count"] = 0
    return response
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8000, host="0.0.0.0")
|