File size: 6,643 Bytes
8bab08d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
# file: app/main.py
import json
from datetime import datetime
from typing import AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.encoders import jsonable_encoder
from app.schema import PipelineRequest, WriterStreamRequest, Prospect, HandoffPacket
from app.orchestrator import Orchestrator
from app.config import MODEL_NAME, HF_API_TOKEN
from app.logging_utils import setup_logging
from mcp.registry import MCPRegistry
from vector.store import VectorStore
import requests
# Logging is configured first, before any of the module-level objects below
# are constructed.
setup_logging()
# The FastAPI application; every route/event decorator in this module
# registers on this instance.
app = FastAPI(title="CX AI Agent", version="1.0.0")
# Module-level singletons shared by the handlers below: the pipeline
# orchestrator, the MCP client registry (connected in the startup handler),
# and the vector store (rebuilt by /reset).
orchestrator = Orchestrator()
mcp = MCPRegistry()
vector_store = VectorStore()
@app.on_event("startup")
async def startup():
    """Connect the MCP registry when the application starts.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of a ``lifespan`` context manager passed to ``FastAPI(...)``.
    """
    connect = mcp.connect
    await connect()
@app.get("/health")
async def health():
    """Report service health.

    HF inference is only checked for *configuration*: ``configured`` is True
    when ``HF_API_TOKEN`` is non-empty. No request is sent to the HF API, so
    a True value does not prove the API is reachable. The MCP entry is
    whatever ``mcp.health_check()`` returns, and ``vector_store`` is the
    result of ``vector_store.is_initialized()``.

    Returns:
        A dict with ``status``, ``timestamp``, ``hf_inference``, ``mcp`` and
        ``vector_store`` keys. If any probe raises, a 503 ``JSONResponse``
        with ``status: "unhealthy"`` and the error text is returned instead.
    """
    import logging

    try:
        # Configuration check only; a token being set is not a connectivity test.
        hf_ok = bool(HF_API_TOKEN)
        mcp_status = await mcp.health_check()
        vector_ok = vector_store.is_initialized()
    except Exception as e:
        # Top-level boundary: report unhealthy to the caller, but keep the
        # traceback in the logs instead of silently discarding it.
        logging.getLogger(__name__).exception("Health check failed")
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(e)},
        )

    return {
        "status": "healthy",
        "timestamp": datetime.utcnow().isoformat(),
        "hf_inference": {
            "configured": hf_ok,
            "model": MODEL_NAME,
        },
        "mcp": mcp_status,
        "vector_store": vector_ok,
    }
async def stream_pipeline(request: PipelineRequest) -> AsyncGenerator[bytes, None]:
    """Yield pipeline events as newline-delimited JSON byte chunks.

    The request's ``company_ids``, ``company_names`` and ``use_seed_file``
    are forwarded to ``orchestrator.run_pipeline``. This covers both the
    dynamic by-name mode and the legacy by-id mode.
    """
    events = orchestrator.run_pipeline(
        company_ids=request.company_ids,
        company_names=request.company_names,
        use_seed_file=request.use_seed_file,
    )
    async for evt in events:
        # jsonable_encoder flattens nested Pydantic models (e.g. Prospect)
        # into plain JSON-compatible data before serialisation.
        line = json.dumps(jsonable_encoder(evt)) + "\n"
        yield line.encode()
@app.post("/run")
async def run_pipeline(request: PipelineRequest):
    """
    Run the full pipeline, streaming events back as NDJSON.

    Dynamic mode (preferred): pass ``company_names`` for discovery.
    Legacy mode: pass ``company_ids`` (optionally with ``use_seed_file``).

    Example (Dynamic):
    {"company_names": ["Shopify", "Stripe", "Zendesk"]}

    Example (Legacy):
    {"company_ids": ["acme", "techcorp"], "use_seed_file": true}
    """
    body = stream_pipeline(request)
    return StreamingResponse(body, media_type="application/x-ndjson")
async def stream_writer_test(company_id: str) -> AsyncGenerator[bytes, None]:
    """Yield only the Writer agent's events, as NDJSON bytes, for one company.

    The company is looked up in the MCP store. If it is missing, a single
    ``{"error": ...}`` line is emitted and the stream ends. Otherwise a
    synthetic prospect is built (id ``"<company_id>_test"``, fit_score 0.8,
    status ``"scored"``, no contacts or facts) and streamed through
    ``Writer.run_streaming``.
    """
    from agents.writer import Writer

    def to_line(payload) -> bytes:
        # One JSON document per line (NDJSON framing).
        return (json.dumps(payload) + "\n").encode()

    store = mcp.get_store_client()
    company = await store.get_company(company_id)
    if not company:
        yield to_line({"error": f"Company {company_id} not found"})
        return

    test_prospect = Prospect(
        id=f"{company_id}_test",
        company=company,
        contacts=[],
        facts=[],
        fit_score=0.8,
        status="scored",
    )
    writer = Writer(mcp)
    async for evt in writer.run_streaming(test_prospect):
        # Nested Pydantic models must be made JSON-compatible first.
        yield to_line(jsonable_encoder(evt))
@app.post("/writer/stream")
async def writer_stream_test(request: WriterStreamRequest):
    """Stream the Writer agent's output for ``request.company_id`` (testing aid)."""
    generator = stream_writer_test(request.company_id)
    return StreamingResponse(generator, media_type="application/x-ndjson")
@app.get("/prospects")
async def list_prospects():
    """Return a summary row for every stored prospect.

    Each row holds the prospect id, company name, status, fit score, and the
    number of contacts and facts attached to the prospect.
    """

    def summarize(p):
        return {
            "id": p.id,
            "company": p.company.name,
            "status": p.status,
            "fit_score": p.fit_score,
            "contacts": len(p.contacts),
            "facts": len(p.facts),
        }

    store = mcp.get_store_client()
    prospects = await store.list_prospects()
    total = len(prospects)
    return {"count": total, "prospects": [summarize(p) for p in prospects]}
@app.get("/prospects/{prospect_id}")
async def get_prospect(prospect_id: str):
    """Return one prospect's full record plus its email thread, if any.

    Raises:
        HTTPException: 404 when the store has no prospect with this id.
    """
    store = mcp.get_store_client()
    prospect = await store.get_prospect(prospect_id)
    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")

    email_client = mcp.get_email_client()
    # NOTE(review): the guard tests ``thread_id`` but the lookup passes
    # ``prospect.id``; confirm get_thread is keyed by prospect id.
    thread = await email_client.get_thread(prospect.id) if prospect.thread_id else None

    return {
        "prospect": prospect.dict(),
        "thread": None if not thread else thread.dict(),
    }
@app.get("/handoff/{prospect_id}")
async def get_handoff(prospect_id: str):
    """Build the handoff packet (prospect, thread, calendar slots) for a prospect.

    Raises:
        HTTPException: 404 if the prospect does not exist, or 400 if its
            status is anything other than ``"ready_for_handoff"``.
    """
    store = mcp.get_store_client()
    prospect = await store.get_prospect(prospect_id)
    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")
    if prospect.status != "ready_for_handoff":
        raise HTTPException(
            status_code=400,
            detail=f"Prospect not ready for handoff (status: {prospect.status})",
        )

    email_client = mcp.get_email_client()
    # NOTE(review): the guard tests ``thread_id`` but the lookup passes
    # ``prospect.id``; confirm get_thread is keyed by prospect id.
    thread = await email_client.get_thread(prospect.id) if prospect.thread_id else None

    calendar_client = mcp.get_calendar_client()
    suggested = await calendar_client.suggest_slots()

    return HandoffPacket(
        prospect=prospect,
        thread=thread,
        calendar_slots=suggested,
        generated_at=datetime.utcnow(),
    ).dict()
@app.post("/reset")
async def reset_system():
    """Clear the store, reload the seed companies, and rebuild the vector index.

    The seed file (``COMPANIES_FILE``, a JSON array of company records) is
    read *before* the store is cleared. A missing or malformed seed file
    therefore raises without wiping existing data. Each record is then saved
    via ``store.save_company``, and ``vector_store.rebuild_index()`` is called.

    Returns:
        dict with ``status``, ``companies_loaded`` (number of seed records)
        and ``timestamp``.
    """
    from app.config import COMPANIES_FILE

    # Load seeds first: clearing the store and then failing to read the file
    # would leave the system empty.
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(COMPANIES_FILE, encoding="utf-8") as f:
        companies = json.load(f)

    store = mcp.get_store_client()
    await store.clear_all()
    for company_data in companies:
        await store.save_company(company_data)

    vector_store.rebuild_index()

    return {
        "status": "reset_complete",
        "companies_loaded": len(companies),
        "timestamp": datetime.utcnow().isoformat(),
    }
if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|