|
|
|
|
|
import json |
|
|
from datetime import datetime |
|
|
from typing import AsyncGenerator |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
|
from fastapi.encoders import jsonable_encoder |
|
|
from app.schema import PipelineRequest, WriterStreamRequest, Prospect, HandoffPacket |
|
|
from app.orchestrator import Orchestrator |
|
|
from app.config import MODEL_NAME, HF_API_TOKEN |
|
|
from app.logging_utils import setup_logging |
|
|
from mcp.registry import MCPRegistry |
|
|
from vector.store import VectorStore |
|
|
import requests |
|
|
|
|
|
# Configure logging once, at import time, before any application object is created.
setup_logging()

# FastAPI application that every route in this module is registered on.
app = FastAPI(title="CX AI Agent", version="1.0.0")

# Module-level singletons shared by the request handlers below:
#   orchestrator - drives the pipeline streamed by the /run endpoint
#   mcp          - registry handing out store / email / calendar clients;
#                  connected in the startup handler
#   vector_store - reported by /health and rebuilt by /reset
orchestrator = Orchestrator()
mcp = MCPRegistry()
vector_store = VectorStore()
|
|
|
|
|
@app.on_event("startup")
async def startup():
    """Connect the shared MCP registry when the application starts.

    Registered as a FastAPI startup handler; it only awaits
    ``mcp.connect()`` and returns nothing.
    """
    await mcp.connect()
|
|
|
|
|
@app.get("/health")
async def health():
    """Report the health of the service and its dependencies.

    The Hugging Face entry only reflects whether ``HF_API_TOKEN`` is set;
    no request is sent to the inference API from this handler.

    Returns:
        On success, a dict with ``status`` ("healthy"), a UTC ISO
        ``timestamp``, the ``hf_inference`` configuration, the result of
        ``mcp.health_check()`` and ``vector_store.is_initialized()``.
        If anything raises while the report is assembled, a 503
        ``JSONResponse`` carrying the error text.
    """
    try:
        token_configured = bool(HF_API_TOKEN)
        registry_status = await mcp.health_check()

        # The whole report is built inside the try so that a failure in any
        # dependency probe is reported as "unhealthy" rather than a 500.
        report = {
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
            "hf_inference": {
                "configured": token_configured,
                "model": MODEL_NAME,
            },
            "mcp": registry_status,
            "vector_store": vector_store.is_initialized(),
        }
        return report
    except Exception as exc:
        failure = {"status": "unhealthy", "error": str(exc)}
        return JSONResponse(status_code=503, content=failure)
|
|
|
|
|
async def stream_pipeline(request: PipelineRequest) -> AsyncGenerator[bytes, None]:
    """Yield the orchestrator's pipeline events as NDJSON-encoded bytes.

    ``company_ids``, ``company_names`` and ``use_seed_file`` are forwarded
    unchanged from the request to ``orchestrator.run_pipeline``, so both the
    dynamic (names) and legacy (ids) request shapes are passed through.

    Args:
        request: Parsed pipeline request body.

    Yields:
        One encoded line per event: the JSON-serialised event followed by a
        newline.
    """
    events = orchestrator.run_pipeline(
        company_ids=request.company_ids,
        company_names=request.company_names,
        use_seed_file=request.use_seed_file,
    )
    async for item in events:
        # jsonable_encoder converts the event into JSON-compatible data
        # before it is dumped.
        serialized = json.dumps(jsonable_encoder(item))
        yield f"{serialized}\n".encode()
|
|
|
|
|
@app.post("/run")
async def run_pipeline(request: PipelineRequest):
    """Run the full pipeline and stream its events back as NDJSON.

    Two request shapes are accepted:

    Dynamic discovery (by company name):
        {"company_names": ["Shopify", "Stripe", "Zendesk"]}

    Legacy (by company id, kept for backwards compatibility):
        {"company_ids": ["acme", "techcorp"], "use_seed_file": true}

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson``
        whose body is produced by ``stream_pipeline``.
    """
    event_lines = stream_pipeline(request)
    return StreamingResponse(event_lines, media_type="application/x-ndjson")
|
|
|
|
|
async def stream_writer_test(company_id: str) -> AsyncGenerator[bytes, None]:
    """Yield NDJSON lines produced by running only the Writer agent.

    The company is looked up in the MCP store. When it is missing, a single
    error line is emitted and the stream ends. Otherwise a throwaway
    prospect (no contacts, no facts, fixed fit score) is built around the
    company and passed to ``Writer.run_streaming``.

    Args:
        company_id: Identifier of the company to look up in the store.

    Yields:
        One encoded JSON line per Writer event, or one error line.
    """
    # Imported here rather than at module level, as in the original code.
    from agents.writer import Writer

    company = await mcp.get_store_client().get_company(company_id)
    if not company:
        not_found = {"error": f"Company {company_id} not found"}
        yield (json.dumps(not_found) + "\n").encode()
        return

    # Minimal prospect used purely to exercise the Writer.
    test_prospect = Prospect(
        id=f"{company_id}_test",
        company=company,
        contacts=[],
        facts=[],
        fit_score=0.8,
        status="scored",
    )

    agent = Writer(mcp)
    async for item in agent.run_streaming(test_prospect):
        serialized = json.dumps(jsonable_encoder(item))
        yield f"{serialized}\n".encode()
|
|
|
|
|
@app.post("/writer/stream")
async def writer_stream_test(request: WriterStreamRequest):
    """Test endpoint that streams Writer agent output as NDJSON.

    Returns:
        A ``StreamingResponse`` (``application/x-ndjson``) backed by
        ``stream_writer_test`` for the requested ``company_id``.
    """
    writer_lines = stream_writer_test(request.company_id)
    return StreamingResponse(writer_lines, media_type="application/x-ndjson")
|
|
|
|
|
@app.get("/prospects")
async def list_prospects():
    """Return a summary of every prospect held in the store.

    Returns:
        A dict with ``count`` (number of prospects) and ``prospects``, a
        list of per-prospect summaries: id, company name, status, fit score
        and the number of contacts and facts.
    """
    records = await mcp.get_store_client().list_prospects()
    total = len(records)

    summaries = []
    for record in records:
        summaries.append(
            {
                "id": record.id,
                "company": record.company.name,
                "status": record.status,
                "fit_score": record.fit_score,
                "contacts": len(record.contacts),
                "facts": len(record.facts),
            }
        )

    return {"count": total, "prospects": summaries}
|
|
|
|
|
@app.get("/prospects/{prospect_id}")
async def get_prospect(prospect_id: str):
    """Return one prospect together with its email thread, if it has one.

    Args:
        prospect_id: Identifier of the prospect to fetch from the store.

    Returns:
        A dict with ``prospect`` (the prospect as a dict) and ``thread``
        (the thread as a dict, or ``None`` when there is no thread).

    Raises:
        HTTPException: 404 when the store has no such prospect.
    """
    record = await mcp.get_store_client().get_prospect(prospect_id)
    if not record:
        raise HTTPException(status_code=404, detail="Prospect not found")

    email_client = mcp.get_email_client()
    # NOTE(review): the lookup is gated on thread_id but keyed by the
    # prospect's id -- confirm this matches the email client's get_thread.
    conversation = (
        await email_client.get_thread(record.id) if record.thread_id else None
    )

    thread_payload = conversation.dict() if conversation else None
    return {"prospect": record.dict(), "thread": thread_payload}
|
|
|
|
|
@app.get("/handoff/{prospect_id}")
async def get_handoff(prospect_id: str):
    """Build and return the handoff packet for a prospect.

    The packet bundles the prospect, its email thread (when one exists),
    the slots suggested by the calendar client and a UTC generation
    timestamp.

    Args:
        prospect_id: Identifier of the prospect to hand off.

    Returns:
        The ``HandoffPacket`` as a dict.

    Raises:
        HTTPException: 404 when the prospect does not exist; 400 when its
            status is anything other than ``"ready_for_handoff"``.
    """
    prospect = await mcp.get_store_client().get_prospect(prospect_id)

    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")

    if prospect.status != "ready_for_handoff":
        not_ready = f"Prospect not ready for handoff (status: {prospect.status})"
        raise HTTPException(status_code=400, detail=not_ready)

    email_client = mcp.get_email_client()
    # NOTE(review): the lookup is gated on thread_id but keyed by the
    # prospect's id -- confirm this matches the email client's get_thread.
    conversation = (
        await email_client.get_thread(prospect.id) if prospect.thread_id else None
    )

    suggested_slots = await mcp.get_calendar_client().suggest_slots()

    handoff = HandoffPacket(
        prospect=prospect,
        thread=conversation,
        calendar_slots=suggested_slots,
        generated_at=datetime.utcnow(),
    )
    return handoff.dict()
|
|
|
|
|
@app.post("/reset")
async def reset_system():
    """Clear the store, reload the seed companies and rebuild the vector index.

    The seed file is read and parsed *before* the store is cleared, so a
    missing or malformed ``COMPANIES_FILE`` aborts the request without
    leaving the store empty.

    Returns:
        A dict with ``status`` ("reset_complete"), ``companies_loaded``
        (number of seed entries saved) and a UTC ISO ``timestamp``.

    Raises:
        OSError: If the seed file cannot be opened.
        json.JSONDecodeError: If the seed file is not valid JSON.
    """
    # Kept as a function-level import, as in the original code.
    from app.config import COMPANIES_FILE

    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(COMPANIES_FILE, encoding="utf-8") as seed_file:
        companies = json.load(seed_file)

    store = mcp.get_store_client()
    await store.clear_all()

    for company_data in companies:
        await store.save_company(company_data)

    # Rebuild the vector index after the seed companies have been saved.
    vector_store.rebuild_index()

    return {
        "status": "reset_complete",
        "companies_loaded": len(companies),
        "timestamp": datetime.utcnow().isoformat(),
    }
|
|
|
|
|
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # Listen on all IPv4 interfaces (0.0.0.0), port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|