Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers" | |
| os.environ["HF_HOME"] = "/tmp/huggingface" | |
| os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers" | |
| os.environ["TORCH_HOME"] = "/tmp/torch" | |
| import json | |
| import asyncio | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict, Set | |
| import chromadb | |
| from chromadb.config import Settings | |
| from sentence_transformers import SentenceTransformer | |
| # Import from autonomous agent | |
| from agent_langchain import ( | |
| process_with_agent, | |
| get_conversation_history, | |
| classify_ticket, | |
| call_routing, | |
| get_kb_collection, | |
| encoder, | |
| conversations | |
| ) | |
app = FastAPI(title="Smart Helpdesk AI Agent - Autonomous + WebSocket")

# CORS for the frontend.
# The FastAPI CORS documentation states that wildcard origins must not be
# combined with allow_credentials=True. Allowed origins are therefore read
# from the CORS_ALLOW_ORIGINS environment variable (comma-separated list).
# When the variable is unset the API stays reachable from every origin, but
# credential support is switched off; it is only enabled for explicit origins.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # Request Models | |
class TicketRequest(BaseModel):
    """Request body accepted by the REST-style endpoints of this module."""

    # Ticket / message text supplied by the user (required).
    text: str
    # Conversation to continue; defaults to None when not supplied.
    conversation_id: Optional[str] = None
    # E-mail address of the requester; defaults to None when not supplied.
    user_email: Optional[str] = None
| # WebSocket Connection Manager | |
class ConnectionManager:
    """Registry of open WebSocket connections, keyed by conversation id."""

    def __init__(self):
        # conversation_id -> WebSocket currently registered under that id
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept the handshake and register the socket under *conversation_id*.

        A socket already stored under the same id is replaced.
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        print(f"🔌 WebSocket connected: {conversation_id}")

    def disconnect(self, conversation_id: str):
        """Forget the socket registered for *conversation_id*; no-op if unknown."""
        if conversation_id not in self.active_connections:
            return
        self.active_connections.pop(conversation_id)
        print(f"🔌 WebSocket disconnected: {conversation_id}")

    async def send_message(self, conversation_id: str, message: dict):
        """Send *message* as JSON to the socket registered for *conversation_id*.

        Nothing happens when no socket is registered. A failing send is
        logged and the registration is dropped; the error is not re-raised.
        """
        if conversation_id not in self.active_connections:
            return
        socket = self.active_connections[conversation_id]
        try:
            await socket.send_json(message)
        except Exception as e:
            print(f"Error sending message: {e}")
            self.disconnect(conversation_id)
# Module-wide registry of open WebSocket connections.
manager = ConnectionManager()

# Chroma storage location and collection name for the knowledge base.
# NOTE(review): the store lives under /tmp — confirm that losing it whenever
# the temp directory is wiped is acceptable for this deployment.
CHROMA_PATH = "/tmp/chroma"
COLLECTION_NAME = "knowledge_base"
| # ------------------------------- | |
| # KB Setup Endpoint | |
| # ------------------------------- | |
def _build_kb_records(data):
    """Convert raw KB items into parallel ``(texts, ids, metadatas)`` lists.

    Items that are not JSON objects, or that carry no text under the
    ``answer`` / ``text`` / ``content`` keys, are skipped with a log line.
    """
    texts, ids, metadatas = [], [], []
    for i, item in enumerate(data):
        if not isinstance(item, dict):
            # Without this guard item.get() raises AttributeError and the
            # whole upload fails because of one malformed entry.
            print(f"⚠️ Skipping item {i} - not a JSON object")
            continue
        text = item.get("answer") or item.get("text") or item.get("content") or ""
        if not text:
            print(f"⚠️ Skipping item {i} - no text content")
            continue
        text = str(text)
        item_id = str(item.get("id") or i)
        # "category": null in the JSON must not end up as None in the metadata.
        category = item.get("category") or ""
        texts.append(f"Category: {category}. {text}" if category else text)
        ids.append(item_id)
        metadatas.append({"id": item_id, "category": category, "original_index": i})
    return texts, ids, metadatas


async def setup_kb(kb_file: UploadFile = File(...)):
    """Upload and index the knowledge base.

    The uploaded file must contain a JSON list of objects. Existing records
    of the collection are removed, the new texts are embedded with the shared
    ``encoder`` and stored in Chroma, and ``agent_langchain.kb_collection``
    is pointed at the refreshed collection.

    Returns:
        ``{"message": ..., "count": <records now stored>}``.

    Raises:
        HTTPException: 400 for invalid JSON, a non-list payload or a payload
            without usable text; 500 for any other failure.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text (root() advertises "/setup") — confirm it is registered.
    try:
        content_bytes = await kb_file.read()
        data = json.loads(content_bytes)
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail="JSON must be a list of items.")
        print(f"📘 Loaded {len(data)} items from {kb_file.filename}")

        chroma_client = chromadb.PersistentClient(
            path=CHROMA_PATH,
            settings=Settings(anonymized_telemetry=False, allow_reset=True),
        )
        collection = chroma_client.get_or_create_collection(COLLECTION_NAME)
        existing = collection.count()
        if existing > 0:
            print(f"🧹 Clearing {existing} existing records...")
            collection.delete(ids=collection.get()['ids'])

        texts, ids, metadatas = _build_kb_records(data)
        if not texts:
            raise HTTPException(status_code=400, detail="No valid text content found in JSON.")

        print("🧠 Generating embeddings...")
        embeddings = encoder.encode(texts, show_progress_bar=True).tolist()
        print("💾 Adding to ChromaDB...")
        collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)

        # Update global reference
        import agent_langchain
        agent_langchain.kb_collection = collection

        stored = collection.count()
        print(f"✅ Successfully added {stored} records")
        return {"message": "Knowledge base initialized", "count": stored}
    except HTTPException:
        # Let the deliberate 4xx responses above through unchanged; the
        # generic handler below would otherwise turn them into 500s.
        raise
    except (json.JSONDecodeError, UnicodeDecodeError):
        raise HTTPException(status_code=400, detail="Invalid JSON file.")
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Setup failed: {str(e)}")
| # ------------------------------- | |
| # WebSocket Endpoint (REAL-TIME BIDIRECTIONAL) | |
| # ------------------------------- | |
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
    """
    WebSocket endpoint for real-time agent communication.

    Client sends: {"text": "My issue description", "user_email": "user@example.com"}

    Server streams:
    - {"type": "status", "message": "Agent is thinking..."}
    - {"type": "tool_use", "tool": "SearchKnowledgeBase", "input": "..."}
    - {"type": "response", "content": "Here's the solution..."}
    - {"type": "saved", "firestore_id": "abc123"}
    - {"type": "error", "message": "..."} for unusable input or failures

    Updates passed to the agent callback are forwarded to the client as-is.
    A malformed frame is answered with an error message and the connection
    stays open; any other exception is reported to the client (best effort)
    and ends the session.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text (root() advertises "/ws/{conversation_id}") — confirm it
    # is registered.
    await manager.connect(websocket, conversation_id)
    # get_event_loop() is deprecated inside coroutines; the running loop is
    # the one the worker thread must schedule its sends on.
    loop = asyncio.get_running_loop()

    def forward_update(update):
        # Safe to call from the executor thread: the actual send is scheduled
        # on the event loop, and the resulting future is handed back.
        return asyncio.run_coroutine_threadsafe(
            manager.send_message(conversation_id, update), loop
        )

    try:
        while True:
            # Receive message from client
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "Invalid JSON payload"
                })
                continue
            if not isinstance(data, dict):
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "Payload must be a JSON object"
                })
                continue

            user_message = data.get("text")
            user_email = data.get("user_email")
            if not user_message:
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "No text provided"
                })
                continue

            # Send thinking status
            await manager.send_message(conversation_id, {
                "type": "status",
                "message": "🤔 Analyzing your request..."
            })

            # Process with agent (in thread to avoid blocking)
            result = await loop.run_in_executor(
                None,
                lambda: process_with_agent(
                    user_message=user_message,
                    conversation_id=conversation_id,
                    user_email=user_email,
                    callback=forward_update
                )
            )

            # Send final response
            await manager.send_message(conversation_id, {
                "type": "response",
                "conversation_id": result["conversation_id"],
                "content": result["response"],
                "status": result["status"],
                "ticket_info": result.get("ticket_info", {}),
                "message_count": result["message_count"],
                "firestore_id": result.get("firestore_id")
            })
    except WebSocketDisconnect:
        manager.disconnect(conversation_id)
        print(f"Client disconnected: {conversation_id}")
    except Exception as e:
        print(f"WebSocket error: {e}")
        import traceback
        traceback.print_exc()
        try:
            await manager.send_message(conversation_id, {
                "type": "error",
                "message": str(e)
            })
        except Exception as send_error:
            # Best effort only: the socket may already be unusable.
            print(f"Could not report error to client: {send_error}")
        manager.disconnect(conversation_id)
| # ------------------------------- | |
| # REST Endpoint (backward compatible) | |
| # ------------------------------- | |
async def orchestrate_endpoint(ticket: TicketRequest):
    """
    REST endpoint for agent interaction (backward compatible).
    Use WebSocket for real-time experience.

    The agent call is synchronous, so it is run in the default executor,
    as the WebSocket endpoint does; otherwise every other request and open
    socket would stall until the agent finishes.

    Returns:
        The agent result (conversation id, response, status, ticket info,
        message count, reasoning trace, Firestore id) plus usage hints.

    Raises:
        HTTPException: 500 when the agent call or result handling fails.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text (root() advertises "/orchestrate") — confirm it is
    # registered.
    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: process_with_agent(
                user_message=ticket.text,
                conversation_id=ticket.conversation_id,
                user_email=ticket.user_email
            )
        )
        return {
            "conversation_id": result["conversation_id"],
            "response": result["response"],
            "status": result["status"],
            "ticket_info": result.get("ticket_info", {}),
            "message_count": result["message_count"],
            "reasoning_trace": result.get("reasoning_trace", []),
            "firestore_id": result.get("firestore_id"),
            "instructions": {
                "websocket": f"ws://your-domain/ws/{result['conversation_id']}",
                "continue_conversation": "Include the conversation_id in your next request"
            }
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Agent failed: {str(e)}")
| # ------------------------------- | |
| # Get Conversation History | |
| # ------------------------------- | |
async def get_conversation(conversation_id: str):
    """Return the stored history of one conversation.

    Raises:
        HTTPException: 404 when no history exists for *conversation_id*.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text — confirm it is registered on the app.
    record = get_conversation_history(conversation_id)
    if record:
        history = record["messages"]
        return {
            "conversation_id": conversation_id,
            "messages": history,
            "ticket_info": record.get("ticket_info", {}),
            "status": record.get("status", "unknown"),
            "created_at": record["created_at"],
            "message_count": len(history)
        }
    raise HTTPException(status_code=404, detail="Conversation not found")
| # ------------------------------- | |
| # List Active Conversations | |
| # ------------------------------- | |
async def list_conversations():
    """Summarise every entry of the shared ``conversations`` mapping.

    Each summary carries status, message count, creation time, user e-mail
    and the first 100 characters of the latest message (None when the
    conversation has no messages). Summaries are ordered newest first.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text (root() advertises "/conversations") — confirm it is
    # registered.

    def summarize(conv_id, conv_data):
        # One list entry per conversation.
        messages = conv_data["messages"]
        return {
            "conversation_id": conv_id,
            "status": conv_data.get("status", "unknown"),
            "message_count": len(messages),
            "created_at": conv_data["created_at"],
            "user_email": conv_data.get("user_email", "anonymous"),
            "last_message": messages[-1]["content"][:100] if messages else None
        }

    summaries = [summarize(conv_id, conv_data) for conv_id, conv_data in conversations.items()]
    summaries.sort(key=lambda entry: entry["created_at"], reverse=True)
    return {
        "total": len(summaries),
        "conversations": summaries
    }
| # ------------------------------- | |
| # Individual Tool Endpoints (for testing) | |
| # ------------------------------- | |
async def classify_endpoint(ticket: TicketRequest):
    """Run only the classification step on the ticket text (testing aid)."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text — confirm it is registered on the app.
    return {"classification": classify_ticket(ticket.text)}
async def route_endpoint(ticket: TicketRequest):
    """Run only the routing step on the ticket text (testing aid)."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text — confirm it is registered on the app.
    return {"department": call_routing(ticket.text)}
async def kb_query_endpoint(ticket: TicketRequest):
    """Look up the single best knowledge-base match for the ticket text.

    Returns:
        ``{"answer": <document>, "confidence": <float rounded to 3 places>}``,
        or a "No relevant KB found." answer with confidence 0.0 when the
        query yields no document.

    Raises:
        HTTPException: 400 when the knowledge base is missing or empty,
            500 when embedding or querying fails.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text — confirm it is registered on the app.
    collection = get_kb_collection()
    if not (collection and collection.count() != 0):
        raise HTTPException(status_code=400, detail="KB not set up. Call /setup first.")

    try:
        encoded = encoder.encode([ticket.text])
        query_vector = encoded[0].tolist()
        hits = collection.query(
            query_embeddings=[query_vector],
            n_results=1,
            include=["documents", "distances", "metadatas"]
        )

        documents = hits.get('documents') if hits else None
        if not documents or len(documents[0]) == 0:
            return {"answer": "No relevant KB found.", "confidence": 0.0}

        top_document = documents[0][0]
        if hits.get('distances'):
            top_distance = hits['distances'][0][0]
        else:
            # No distance reported: use 1.0, which maps to a confidence of 0.5.
            top_distance = 1.0

        # Map the distance linearly onto a score: 0 -> 1.0, 2 or more -> 0.0.
        score = max(0.0, 1.0 - (top_distance / 2.0))
        return {"answer": top_document, "confidence": round(float(score), 3)}
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"KB query failed: {str(e)}")
| # ------------------------------- | |
| # Health Check | |
| # ------------------------------- | |
async def health():
    """Report service health: KB state, conversation and socket counts."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text (root() advertises "/health") — confirm it is registered.
    collection = get_kb_collection()
    # Count once so the reported status and record count cannot disagree
    # (and the store is not queried twice per health probe).
    kb_count = collection.count() if collection else 0
    kb_status = "initialized" if kb_count > 0 else "not initialized"
    return {
        "status": "ok",
        "kb_status": kb_status,
        "kb_records": kb_count,
        "active_conversations": len(conversations),
        "active_websockets": len(manager.active_connections),
        "agent_type": "Autonomous ReAct Agent with Gemini + WebSocket"
    }
| # ------------------------------- | |
| # Root endpoint | |
| # ------------------------------- | |
async def root():
    """Return a static description of the API and its main endpoints."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text — confirm it is registered on the app.
    endpoint_map = {
        "websocket": "/ws/{conversation_id}",
        "rest": "/orchestrate",
        "setup_kb": "/setup",
        "conversations": "/conversations",
        "health": "/health"
    }
    description = {"message": "Smart Helpdesk AI Agent API"}
    description["endpoints"] = endpoint_map
    description["documentation"] = "/docs"
    return description