import os

# Point every model/tokenizer cache at /tmp. These assignments deliberately come
# before the third-party imports below so those libraries see them when they load.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers"
os.environ["TORCH_HOME"] = "/tmp/torch"

import asyncio
import functools
import json
from typing import Dict, Optional, Set

import chromadb
from chromadb.config import Settings
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Import from autonomous agent
from agent_langchain import (
    call_routing,
    classify_ticket,
    conversations,
    encoder,
    get_conversation_history,
    get_kb_collection,
    process_with_agent,
)

app = FastAPI(title="Smart Helpdesk AI Agent - Autonomous + WebSocket")

# CORS for frontend.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive; restrict allow_origins to the real frontend before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with your frontend URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request Models
class TicketRequest(BaseModel):
    """Request body shared by the REST endpoints."""

    text: str  # the user's message / ticket text
    conversation_id: Optional[str] = None  # include to continue an existing conversation
    user_email: Optional[str] = None


# WebSocket Connection Manager
class ConnectionManager:
    """Registry mapping each conversation id to its currently open WebSocket."""

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept the handshake and register the socket (replacing any previous one)."""
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        print(f"🔌 WebSocket connected: {conversation_id}")

    def disconnect(self, conversation_id: str, websocket: Optional[WebSocket] = None):
        """Forget the socket registered for ``conversation_id``.

        When ``websocket`` is given, the entry is removed only if it still refers to
        that exact socket. This stops a stale handler from evicting a newer
        connection that reused the same conversation id. Safe to call repeatedly.
        """
        current = self.active_connections.get(conversation_id)
        if current is None:
            return
        if websocket is not None and current is not websocket:
            return
        del self.active_connections[conversation_id]
        print(f"🔌 WebSocket disconnected: {conversation_id}")

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON to the conversation's socket, if one is registered.

        Send failures are logged and swallowed, and the failing socket is unregistered.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(message)
        except Exception as e:
            print(f"Error sending message: {e}")
            # Pass the socket we actually used: another connection may have
            # registered under this id while the send was awaiting.
            self.disconnect(conversation_id, websocket)


manager = ConnectionManager()

# Persistent Chroma settings
CHROMA_PATH = "/tmp/chroma"
COLLECTION_NAME = "knowledge_base"


# -------------------------------
# KB Setup Endpoint
# -------------------------------
def _build_kb_records(data):
    """Convert uploaded KB items into parallel ``(texts, ids, metadatas)`` lists.

    Items that are not JSON objects, or that have no text under "answer", "text"
    or "content", are skipped with a log line.

    Raises:
        HTTPException: 400 on a duplicate id, since Chroma rejects duplicate ids
            in a single add.
    """
    texts, ids, metadatas = [], [], []
    seen_ids = set()
    for i, item in enumerate(data):
        if not isinstance(item, dict):
            print(f"⚠️ Skipping item {i} - not a JSON object")
            continue
        text = item.get("answer") or item.get("text") or item.get("content") or ""
        if not text:
            print(f"⚠️ Skipping item {i} - no text content")
            continue
        text = str(text)

        # Fall back to the list index only when no id was supplied, so an
        # explicit id of 0 is kept.
        raw_id = item.get("id")
        item_id = str(i) if raw_id is None or raw_id == "" else str(raw_id)
        if item_id in seen_ids:
            raise HTTPException(status_code=400, detail=f"Duplicate id '{item_id}' in KB file.")
        seen_ids.add(item_id)

        # Coerce a null category to "". Chroma has historically rejected None
        # metadata values.
        category = item.get("category") or ""
        combined_text = f"Category: {category}. {text}" if category else text

        texts.append(combined_text)
        ids.append(item_id)
        metadatas.append({"id": item_id, "category": category, "original_index": i})
    return texts, ids, metadatas


def _replace_kb_collection(texts, ids, metadatas, embeddings):
    """Clear the persistent KB collection, load the given records, and return it."""
    chroma_client = chromadb.PersistentClient(
        path=CHROMA_PATH,
        settings=Settings(anonymized_telemetry=False, allow_reset=True),
    )
    collection = chroma_client.get_or_create_collection(COLLECTION_NAME)

    existing = collection.count()
    if existing > 0:
        print(f"🧹 Clearing {existing} existing records...")
        collection.delete(ids=collection.get()["ids"])

    print("💾 Adding to ChromaDB...")
    collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
    return collection


@app.post("/setup")
async def setup_kb(kb_file: UploadFile = File(...)):
    """Upload and index a knowledge base.

    The upload must be a JSON list of objects. Text comes from the first non-empty
    of "answer" / "text" / "content"; "id" and "category" are optional. All records
    are validated and embedded *before* the existing collection is cleared, so a
    rejected upload leaves the current KB untouched.

    Raises:
        HTTPException: 400 for undecodable/invalid JSON, a non-list payload,
            duplicate ids, or no usable text; 500 for any other failure.
    """
    try:
        content_bytes = await kb_file.read()
        data = json.loads(content_bytes)
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail="JSON must be a list of items.")
        print(f"📘 Loaded {len(data)} items from {kb_file.filename}")

        texts, ids, metadatas = _build_kb_records(data)
        if not texts:
            raise HTTPException(status_code=400, detail="No valid text content found in JSON.")

        print("🧠 Generating embeddings...")
        # Encoding is CPU-bound; run it in a worker thread so the event loop keeps
        # serving other requests and WebSockets.
        loop = asyncio.get_running_loop()
        raw_embeddings = await loop.run_in_executor(
            None, functools.partial(encoder.encode, texts, show_progress_bar=True)
        )
        embeddings = raw_embeddings.tolist()

        # Clear + add run without an intervening await, so concurrent /setup calls
        # cannot interleave on the collection.
        collection = _replace_kb_collection(texts, ids, metadatas, embeddings)

        # Update global reference
        import agent_langchain
        agent_langchain.kb_collection = collection

        count = collection.count()
        print(f"✅ Successfully added {count} records")
        return {"message": "Knowledge base initialized", "count": count}
    except HTTPException:
        # Client errors raised above must keep their own status code.
        raise
    except (json.JSONDecodeError, UnicodeDecodeError):
        raise HTTPException(status_code=400, detail="Invalid JSON file.")
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Setup failed: {str(e)}")


# -------------------------------
# WebSocket Endpoint (REAL-TIME BIDIRECTIONAL)
# -------------------------------
@app.websocket("/ws/{conversation_id}")
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
    """
    WebSocket endpoint for real-time agent communication.

    Client sends one JSON object per message, e.g.:
        {"text": "My issue description", "user_email": "user@example.com"}

    Server sends:
        - {"type": "status", "message": ...} when a request is accepted
        - any update dicts process_with_agent passes to its callback (presumably
          events such as "tool_use" / "saved"; their shape is defined by the agent)
        - {"type": "response", ...} with the final agent result
        - {"type": "error", "message": ...} for bad payloads or failures

    A malformed or empty message gets an error reply and the connection stays
    open. Any other exception is reported to the client, and then the handler exits.
    """
    await manager.connect(websocket, conversation_id)
    loop = asyncio.get_running_loop()

    async def ws_callback(update):
        await manager.send_message(conversation_id, update)

    def agent_callback(msg):
        # Called from the worker thread running the agent: schedule the send on
        # the event loop without waiting for it, and hand back the future.
        return asyncio.run_coroutine_threadsafe(ws_callback(msg), loop)

    try:
        while True:
            # Parse the frame here so a malformed message is answered with an
            # error instead of tearing down the whole connection.
            raw_message = await websocket.receive_text()
            try:
                data = json.loads(raw_message)
            except json.JSONDecodeError:
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "Invalid JSON",
                })
                continue
            if not isinstance(data, dict):
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "Message must be a JSON object",
                })
                continue

            user_message = data.get("text")
            user_email = data.get("user_email")

            if not user_message:
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "No text provided",
                })
                continue

            # Send thinking status
            await manager.send_message(conversation_id, {
                "type": "status",
                "message": "🤔 Analyzing your request...",
            })

            # process_with_agent is synchronous; run it in a worker thread so the
            # event loop keeps serving other connections meanwhile.
            result = await loop.run_in_executor(
                None,
                functools.partial(
                    process_with_agent,
                    user_message=user_message,
                    conversation_id=conversation_id,
                    user_email=user_email,
                    callback=agent_callback,
                ),
            )

            # Send final response
            await manager.send_message(conversation_id, {
                "type": "response",
                "conversation_id": result["conversation_id"],
                "content": result["response"],
                "status": result["status"],
                "ticket_info": result.get("ticket_info", {}),
                "message_count": result["message_count"],
                "firestore_id": result.get("firestore_id"),
            })

    except WebSocketDisconnect:
        print(f"Client disconnected: {conversation_id}")
    except Exception as e:
        print(f"WebSocket error: {e}")
        import traceback
        traceback.print_exc()
        # send_message logs and swallows its own send failures.
        await manager.send_message(conversation_id, {
            "type": "error",
            "message": str(e),
        })
    finally:
        # Removes the registry entry only if it still points at this socket.
        manager.disconnect(conversation_id, websocket)


# -------------------------------
# REST Endpoint (backward compatible)
# -------------------------------
async def _run_in_thread(func, *args, **kwargs):
    """Run a blocking callable in the default executor and await its result.

    The agent, classifier, router and sentence encoder are called synchronously.
    Invoking them directly inside an ``async def`` endpoint would block the event
    loop and stall every other request and WebSocket while they run.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


@app.post("/orchestrate")
async def orchestrate_endpoint(ticket: TicketRequest):
    """
    REST endpoint for agent interaction (backward compatible).

    Runs one agent turn and returns the complete result in a single response. Use
    the WebSocket endpoint for streamed, real-time updates.

    Raises:
        HTTPException: 500 if the agent (or building the response) fails.
    """
    try:
        result = await _run_in_thread(
            process_with_agent,
            user_message=ticket.text,
            conversation_id=ticket.conversation_id,
            user_email=ticket.user_email,
        )
        return {
            "conversation_id": result["conversation_id"],
            "response": result["response"],
            "status": result["status"],
            "ticket_info": result.get("ticket_info", {}),
            "message_count": result["message_count"],
            "reasoning_trace": result.get("reasoning_trace", []),
            "firestore_id": result.get("firestore_id"),
            "instructions": {
                # NOTE(review): "your-domain" is a fixed placeholder, not derived
                # from the incoming request's host.
                "websocket": f"ws://your-domain/ws/{result['conversation_id']}",
                "continue_conversation": "Include the conversation_id in your next request",
            },
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Agent failed: {str(e)}")


# -------------------------------
# Get Conversation History
# -------------------------------
@app.get("/conversation/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Retrieve full conversation history.

    Raises:
        HTTPException: 404 when get_conversation_history returns nothing.
    """
    conv = get_conversation_history(conversation_id)
    if not conv:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return {
        "conversation_id": conversation_id,
        "messages": conv["messages"],
        "ticket_info": conv.get("ticket_info", {}),
        "status": conv.get("status", "unknown"),
        "created_at": conv["created_at"],
        "message_count": len(conv["messages"]),
    }


# -------------------------------
# List Active Conversations
# -------------------------------
@app.get("/conversations")
async def list_conversations():
    """List all active conversations, sorted by ``created_at`` descending."""
    # Iterate over a snapshot: process_with_agent runs in worker threads and
    # presumably adds entries to `conversations`. Mutating a dict while it is
    # being iterated raises RuntimeError.
    snapshot = list(conversations.items())
    conv_list = [
        {
            "conversation_id": conv_id,
            "status": conv_data.get("status", "unknown"),
            "message_count": len(conv_data["messages"]),
            "created_at": conv_data["created_at"],
            "user_email": conv_data.get("user_email", "anonymous"),
            "last_message": (
                conv_data["messages"][-1]["content"][:100] if conv_data["messages"] else None
            ),
        }
        for conv_id, conv_data in snapshot
    ]

    return {
        "total": len(conv_list),
        "conversations": sorted(conv_list, key=lambda x: x["created_at"], reverse=True),
    }


# -------------------------------
# Individual Tool Endpoints (for testing)
# -------------------------------
@app.post("/classify")
async def classify_endpoint(ticket: TicketRequest):
    """Test classification only."""
    classification = await _run_in_thread(classify_ticket, ticket.text)
    return {"classification": classification}


@app.post("/route")
async def route_endpoint(ticket: TicketRequest):
    """Test routing only."""
    department = await _run_in_thread(call_routing, ticket.text)
    return {"department": department}


@app.post("/kb_query")
async def kb_query_endpoint(ticket: TicketRequest):
    """Test KB query only: return the closest KB document and a rough confidence.

    Raises:
        HTTPException: 400 if the KB is missing or empty; 500 if the query fails.
    """
    collection = get_kb_collection()
    if not collection or collection.count() == 0:
        raise HTTPException(status_code=400, detail="KB not set up. Call /setup first.")

    try:
        # Encoding is CPU-bound; keep it off the event loop.
        encoded = await _run_in_thread(encoder.encode, [ticket.text])
        query_embedding = encoded[0].tolist()
        result = collection.query(
            query_embeddings=[query_embedding],
            n_results=1,
            include=["documents", "distances", "metadatas"],
        )

        if not result or not result.get("documents") or len(result["documents"][0]) == 0:
            return {"answer": "No relevant KB found.", "confidence": 0.0}

        best_doc = result["documents"][0][0]
        best_distance = result["distances"][0][0] if result.get("distances") else 1.0
        # Map distance to a 0..1 score assuming distances fall in [0, 2] (as for
        # cosine distance). NOTE(review): confirm the collection's distance space;
        # anything beyond 2 is simply clamped to 0.
        confidence = max(0.0, 1.0 - (best_distance / 2.0))

        return {"answer": best_doc, "confidence": round(float(confidence), 3)}
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"KB query failed: {str(e)}")


# -------------------------------
# Health Check
# -------------------------------
@app.get("/health")
async def health():
    """Report KB status plus counts of tracked conversations and open WebSockets."""
    collection = get_kb_collection()
    # Query the collection once and derive both KB fields from that count.
    kb_count = collection.count() if collection else 0
    kb_status = "initialized" if kb_count > 0 else "not initialized"

    return {
        "status": "ok",
        "kb_status": kb_status,
        "kb_records": kb_count,
        "active_conversations": len(conversations),
        "active_websockets": len(manager.active_connections),
        "agent_type": "Autonomous ReAct Agent with Gemini + WebSocket",
    }


# -------------------------------
# Root endpoint
# -------------------------------
@app.get("/")
async def root():
    """Describe the API and list its main endpoints."""
    return {
        "message": "Smart Helpdesk AI Agent API",
        "endpoints": {
            "websocket": "/ws/{conversation_id}",
            "rest": "/orchestrate",
            "setup_kb": "/setup",
            "conversations": "/conversations",
            "health": "/health",
        },
        "documentation": "/docs",
    }