# cx_ai_agent_v1 / app/main.py
# Author: muzakkirhussain011
# Commit 8bab08d: "Add application files (text files only)"
# file: app/main.py
import json
from datetime import datetime
from typing import AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.encoders import jsonable_encoder
from app.schema import PipelineRequest, WriterStreamRequest, Prospect, HandoffPacket
from app.orchestrator import Orchestrator
from app.config import MODEL_NAME, HF_API_TOKEN
from app.logging_utils import setup_logging
from mcp.registry import MCPRegistry
from vector.store import VectorStore
import requests
# Configure logging first so that anything the constructors below log goes
# through the configured handlers.
setup_logging()
# The FastAPI application object; the route decorators in this module register on it.
app = FastAPI(title="CX AI Agent", version="1.0.0")
# Module-level singletons shared by the request handlers in this module.
# NOTE(review): these constructors run at import time, so any I/O they perform
# happens when the module is imported, not per request.
orchestrator = Orchestrator()
mcp = MCPRegistry()
vector_store = VectorStore()
@app.on_event("startup")
async def startup():
    """Open the MCP registry's connections once, via FastAPI's startup hook."""
    # The shared registry is created at import time; connecting is deferred to
    # here because it must be awaited.
    await mcp.connect()
@app.get("/health")
async def health():
    """Report service health.

    This does NOT contact the Hugging Face Inference API: ``hf_inference.configured``
    only reflects whether ``HF_API_TOKEN`` is set (non-empty). MCP status comes from
    ``mcp.health_check()`` and the ``vector_store`` flag from
    ``vector_store.is_initialized()``.

    Returns:
        A dict with ``status="healthy"``, a UTC ISO timestamp, HF configuration,
        MCP status and the vector-store flag; or, if any check raises, a 503
        ``JSONResponse`` with ``status="unhealthy"`` and the error text.
    """
    # Token presence only; no network call is made.
    hf_ok = bool(HF_API_TOKEN)
    try:
        mcp_status = await mcp.health_check()
        vector_ready = vector_store.is_initialized()
    except Exception as e:
        # Top-level boundary: report unhealthy, but record the cause server-side
        # instead of swallowing it.
        import logging

        logging.getLogger(__name__).exception("Health check failed")
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(e)}
        )
    return {
        "status": "healthy",
        "timestamp": datetime.utcnow().isoformat(),
        "hf_inference": {
            "configured": hf_ok,
            "model": MODEL_NAME
        },
        "mcp": mcp_status,
        "vector_store": vector_ready
    }
async def stream_pipeline(request: PipelineRequest) -> AsyncGenerator[bytes, None]:
    """
    Yield each orchestrator pipeline event as one UTF-8 encoded NDJSON line.

    ``company_names`` (dynamic mode), ``company_ids`` (legacy mode) and
    ``use_seed_file`` are forwarded from the request unchanged.
    """
    events = orchestrator.run_pipeline(
        company_ids=request.company_ids,
        company_names=request.company_names,
        use_seed_file=request.use_seed_file,
    )
    async for event in events:
        # jsonable_encoder flattens nested Pydantic models (e.g. Prospect) so
        # json.dumps can serialize them.
        line = json.dumps(jsonable_encoder(event)) + "\n"
        yield line.encode()
@app.post("/run")
async def run_pipeline(request: PipelineRequest):
    """
    Run the full pipeline and stream its events back as NDJSON.

    Dynamic mode (new) discovers companies from ``company_names``:
        {"company_names": ["Shopify", "Stripe", "Zendesk"]}

    Legacy mode, kept for backwards compatibility, uses ``company_ids``:
        {"company_ids": ["acme", "techcorp"], "use_seed_file": true}
    """
    ndjson_body = stream_pipeline(request)
    return StreamingResponse(ndjson_body, media_type="application/x-ndjson")
async def stream_writer_test(company_id: str) -> AsyncGenerator[bytes, None]:
    """
    Stream only the Writer agent's events for one stored company, as NDJSON.

    A synthetic prospect (no contacts or facts, fit score 0.8, status "scored")
    is built around the company. If the company is missing, a single
    ``{"error": ...}`` line is emitted instead.
    """
    from agents.writer import Writer

    def _as_line(payload) -> bytes:
        # One JSON document per line, newline-terminated, UTF-8 encoded.
        return (json.dumps(payload) + "\n").encode()

    company = await mcp.get_store_client().get_company(company_id)
    if not company:
        yield _as_line({"error": f"Company {company_id} not found"})
        return

    synthetic_prospect = Prospect(
        id=f"{company_id}_test",
        company=company,
        contacts=[],
        facts=[],
        fit_score=0.8,
        status="scored",
    )
    async for event in Writer(mcp).run_streaming(synthetic_prospect):
        # Nested Pydantic models need jsonable_encoder before json.dumps.
        yield _as_line(jsonable_encoder(event))
@app.post("/writer/stream")
async def writer_stream_test(request: WriterStreamRequest):
    """Test endpoint for Writer streaming"""
    writer_events = stream_writer_test(request.company_id)
    return StreamingResponse(writer_events, media_type="application/x-ndjson")
@app.get("/prospects")
async def list_prospects():
    """Summarize every stored prospect: status, fit score, and contact/fact counts."""

    def _summarize(prospect):
        # Key order is preserved so the JSON output is unchanged.
        return {
            "id": prospect.id,
            "company": prospect.company.name,
            "status": prospect.status,
            "fit_score": prospect.fit_score,
            "contacts": len(prospect.contacts),
            "facts": len(prospect.facts),
        }

    all_prospects = await mcp.get_store_client().list_prospects()
    total = len(all_prospects)
    return {
        "count": total,
        "prospects": [_summarize(item) for item in all_prospects],
    }
@app.get("/prospects/{prospect_id}")
async def get_prospect(prospect_id: str):
    """
    Return one prospect plus its email thread, if it has one.

    Raises:
        HTTPException: 404 when no prospect with ``prospect_id`` exists.
    """
    prospect = await mcp.get_store_client().get_prospect(prospect_id)
    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")

    email_client = mcp.get_email_client()
    # NOTE(review): presence is checked on prospect.thread_id, but the lookup
    # key is prospect.id. Presumably the email client keys threads by
    # prospect; confirm.
    thread = await email_client.get_thread(prospect.id) if prospect.thread_id else None

    return {
        "prospect": prospect.dict(),
        "thread": None if not thread else thread.dict(),
    }
@app.get("/handoff/{prospect_id}")
async def get_handoff(prospect_id: str):
    """
    Build the handoff packet (prospect, email thread, suggested calendar slots).

    Raises:
        HTTPException: 404 if the prospect does not exist; 400 if its status is
            not ``"ready_for_handoff"``.
    """
    prospect = await mcp.get_store_client().get_prospect(prospect_id)
    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")
    if prospect.status != "ready_for_handoff":
        raise HTTPException(
            status_code=400,
            detail=f"Prospect not ready for handoff (status: {prospect.status})",
        )

    email_client = mcp.get_email_client()
    # NOTE(review): the lookup key is prospect.id, not prospect.thread_id; confirm
    # that the email client keys threads by prospect.
    thread = await email_client.get_thread(prospect.id) if prospect.thread_id else None

    suggested_slots = await mcp.get_calendar_client().suggest_slots()

    return HandoffPacket(
        prospect=prospect,
        thread=thread,
        calendar_slots=suggested_slots,
        generated_at=datetime.utcnow(),
    ).dict()
@app.post("/reset")
async def reset_system():
    """Clear store, reload seeds, rebuild FAISS.

    The seed file (``COMPANIES_FILE``) is read and parsed *before* the store is
    cleared. A missing or malformed seed file therefore leaves the existing data
    intact instead of wiping the store and then failing.

    Returns:
        dict with ``status``, ``companies_loaded`` (number of seed entries saved)
        and a UTC ISO ``timestamp``.

    Raises:
        OSError / json.JSONDecodeError: propagated unhandled if the seed file
            cannot be read or parsed (the store is untouched in that case).
    """
    from app.config import COMPANIES_FILE

    # Load seeds first: clear_all() is destructive, so only run it once we know
    # the store can be repopulated. JSON text is UTF-8 (RFC 8259), so don't
    # depend on the platform's locale encoding.
    with open(COMPANIES_FILE, encoding="utf-8") as f:
        companies = json.load(f)

    store = mcp.get_store_client()
    await store.clear_all()
    for company_data in companies:
        await store.save_company(company_data)

    # Rebuild vector index
    vector_store.rebuild_index()

    return {
        "status": "reset_complete",
        "companies_loaded": len(companies),
        "timestamp": datetime.utcnow().isoformat()
    }
if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)