File size: 6,643 Bytes
8bab08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# file: app/main.py
import json
from datetime import datetime
from typing import AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.encoders import jsonable_encoder
from app.schema import PipelineRequest, WriterStreamRequest, Prospect, HandoffPacket
from app.orchestrator import Orchestrator
from app.config import MODEL_NAME, HF_API_TOKEN
from app.logging_utils import setup_logging
from mcp.registry import MCPRegistry
from vector.store import VectorStore
import requests

# Configure application-wide logging before any component below is created.
setup_logging()

# Module-level singletons shared by every route handler in this file.
app = FastAPI(title="CX AI Agent", version="1.0.0")
orchestrator = Orchestrator()  # drives the multi-agent pipeline (see /run)
mcp = MCPRegistry()            # registry of MCP clients (store/email/calendar)
vector_store = VectorStore()   # vector index; rebuilt by /reset

@app.on_event("startup")
async def startup():
    """Connect all registered MCP servers once when the app boots.

    NOTE(review): ``@app.on_event`` is deprecated in newer FastAPI releases;
    consider migrating to a lifespan context manager.
    """
    await mcp.connect()

@app.get("/health")
async def health():
    """Report service health: HF token presence, MCP servers, vector store.

    Returns a 200 payload on success; any failure while gathering status
    produces a 503 with the error message.
    """
    try:
        # Build the whole payload inside the try so a failing sub-check
        # (MCP, vector store) still maps to the unhealthy response.
        payload = {
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
            "hf_inference": {
                # Token presence only — no live call is made here.
                "configured": bool(HF_API_TOKEN),
                "model": MODEL_NAME,
            },
            "mcp": await mcp.health_check(),
            "vector_store": vector_store.is_initialized(),
        }
    except Exception as exc:
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(exc)}
        )
    return payload

async def stream_pipeline(request: PipelineRequest) -> AsyncGenerator[bytes, None]:
    """Yield one NDJSON-encoded line (as bytes) per pipeline event.

    Supports both dynamic (company_names) and legacy (company_ids) modes.
    """
    events = orchestrator.run_pipeline(
        company_ids=request.company_ids,
        company_names=request.company_names,
        use_seed_file=request.use_seed_file,
    )
    async for event in events:
        # jsonable_encoder makes nested Pydantic models (e.g. Prospect)
        # JSON-serializable before dumping.
        line = json.dumps(jsonable_encoder(event)) + "\n"
        yield line.encode()

@app.post("/run")
async def run_pipeline(request: PipelineRequest):
    """
    Run the full pipeline, streaming results as NDJSON.

    NEW: accepts company_names for dynamic discovery.
    LEGACY: company_ids is still honored for backwards compatibility.

    Example (Dynamic):
    {"company_names": ["Shopify", "Stripe", "Zendesk"]}

    Example (Legacy):
    {"company_ids": ["acme", "techcorp"], "use_seed_file": true}
    """
    body = stream_pipeline(request)
    return StreamingResponse(body, media_type="application/x-ndjson")

async def stream_writer_test(company_id: str) -> AsyncGenerator[bytes, None]:
    """Stream only the Writer agent's output (NDJSON bytes) for one company.

    Emits a single error line and stops if the company is unknown.
    """
    from agents.writer import Writer

    company = await mcp.get_store_client().get_company(company_id)
    if not company:
        payload = {"error": f"Company {company_id} not found"}
        yield (json.dumps(payload) + "\n").encode()
        return

    # Synthetic prospect so the Writer can run without a full pipeline pass.
    test_prospect = Prospect(
        id=f"{company_id}_test",
        company=company,
        contacts=[],
        facts=[],
        fit_score=0.8,
        status="scored",
    )

    async for event in Writer(mcp).run_streaming(test_prospect):
        # jsonable_encoder handles nested Pydantic models in the event.
        yield (json.dumps(jsonable_encoder(event)) + "\n").encode()

@app.post("/writer/stream")
async def writer_stream_test(request: WriterStreamRequest):
    """Test endpoint that streams only the Writer agent's output."""
    generator = stream_writer_test(request.company_id)
    return StreamingResponse(generator, media_type="application/x-ndjson")

@app.get("/prospects")
async def list_prospects():
    """Return a summary (id, status, score, counts) of every prospect."""
    prospects = await mcp.get_store_client().list_prospects()

    summaries = []
    for prospect in prospects:
        summaries.append({
            "id": prospect.id,
            "company": prospect.company.name,
            "status": prospect.status,
            "fit_score": prospect.fit_score,
            "contacts": len(prospect.contacts),
            "facts": len(prospect.facts),
        })

    return {"count": len(prospects), "prospects": summaries}

@app.get("/prospects/{prospect_id}")
async def get_prospect(prospect_id: str):
    """Return full prospect details plus its email thread, if one exists.

    Raises 404 when the prospect id is unknown.
    """
    prospect = await mcp.get_store_client().get_prospect(prospect_id)
    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")

    thread = None
    if prospect.thread_id:
        # NOTE(review): thread is looked up by prospect.id rather than
        # prospect.thread_id — same pattern as /handoff, presumably the
        # email store keys threads by prospect id; confirm against it.
        thread = await mcp.get_email_client().get_thread(prospect.id)

    return {
        "prospect": prospect.dict(),
        "thread": thread.dict() if thread else None,
    }

@app.get("/handoff/{prospect_id}")
async def get_handoff(prospect_id: str):
    """Assemble the handoff packet (prospect, thread, calendar slots).

    Raises 404 for an unknown prospect and 400 when the prospect has not
    reached the ``ready_for_handoff`` status.
    """
    prospect = await mcp.get_store_client().get_prospect(prospect_id)
    if not prospect:
        raise HTTPException(status_code=404, detail="Prospect not found")

    if prospect.status != "ready_for_handoff":
        raise HTTPException(status_code=400, 
                          detail=f"Prospect not ready for handoff (status: {prospect.status})")

    # Attach the email thread when one has been started for this prospect.
    thread = None
    if prospect.thread_id:
        thread = await mcp.get_email_client().get_thread(prospect.id)

    # Proposed meeting slots for the human taking over.
    slots = await mcp.get_calendar_client().suggest_slots()

    packet = HandoffPacket(
        prospect=prospect,
        thread=thread,
        calendar_slots=slots,
        generated_at=datetime.utcnow(),
    )
    return packet.dict()

@app.post("/reset")
async def reset_system():
    """Clear the store, reload seed companies, and rebuild the vector index.

    Returns:
        A summary dict with the reset status, number of companies reloaded,
        and a UTC timestamp.
    """
    # Local import to avoid a module-level config dependency for this one
    # admin endpoint. (The previous redundant ``import json`` here was
    # removed — json is already imported at module level.)
    from app.config import COMPANIES_FILE

    store = mcp.get_store_client()

    # Drop all persisted data before reseeding.
    await store.clear_all()

    # Reload seed companies from the configured JSON file; explicit encoding
    # so the read does not depend on the platform default.
    with open(COMPANIES_FILE, encoding="utf-8") as f:
        companies = json.load(f)

    for company_data in companies:
        await store.save_company(company_data)

    # Rebuild the FAISS index so searches reflect the fresh seed data.
    vector_store.rebuild_index()

    return {
        "status": "reset_complete",
        "companies_loaded": len(companies),
        "timestamp": datetime.utcnow().isoformat()
    }

if __name__ == "__main__":
    # Development entry point only; in production run via an ASGI server
    # invocation such as `uvicorn app.main:app`.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)