Spaces:
Runtime error
Runtime error
| import os | |
| import uuid | |
| import json | |
| import logging | |
| from typing import List, Dict, Any | |
| from dataclasses import dataclass | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
| from groq import Groq | |
# Logging setup: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables: the Groq API key is mandatory, so the module refuses
# to import without it rather than failing later on the first request.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY environment variable is required")

# Load JTBD data. The encoding is passed explicitly because open() otherwise
# uses the platform's locale encoding, which can mis-decode non-ASCII text.
with open("expanded_jtbd.json", "r", encoding="utf-8") as f:
    JTBD_DATA: List[Dict[str, Any]] = json.load(f)["jobs_to_be_done"]
@dataclass
class JTBDItem:
    """One job-to-be-done record held alongside the vector index.

    The ``@dataclass`` decorator is required: the class declares only
    annotations, and ``build_vector_store`` constructs instances with keyword
    arguments, which a plain class without ``__init__`` would reject with
    ``TypeError``.
    """

    name: str
    description: str
    business_function: str
    intent_type: str
    trigger_sources: List[str]
    index: int  # Original index in list for reference
# Global variables for vector store. All three are placeholders here and are
# (re)assigned by build_vector_store().
model = None  # becomes the SentenceTransformer embedding model
index = None  # becomes the FAISS index over the JTBD description embeddings
jtbd_items: List[JTBDItem] = []  # position i corresponds to row i of the index
def build_vector_store():
    """Embed every JTBD description and load the vectors into a FAISS index.

    Reads ``JTBD_DATA`` and assigns the module-level ``model``, ``index`` and
    ``jtbd_items`` globals. Returns nothing.
    """
    global model, index, jtbd_items

    # Initialize embedding model
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # One JTBDItem per raw record, remembering its position in JTBD_DATA so
    # that FAISS row numbers map straight back to items.
    jtbd_items = [
        JTBDItem(
            name=record["name"],
            description=record["description"],
            business_function=record["business_function"],
            intent_type=record["intent_type"],
            trigger_sources=record["trigger_sources"],
            index=position,
        )
        for position, record in enumerate(JTBD_DATA)
    ]
    texts = [record["description"] for record in JTBD_DATA]

    # Embed descriptions; FAISS is fed float32 vectors.
    vectors = np.array(model.encode(texts)).astype('float32')

    # Build a flat L2 index whose dimension is the embedding width.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    logger.info(f"Vector store built with {len(jtbd_items)} JTBD items")
| # Build vector store on startup | |
# Runs at import time, so the embedding model and FAISS index exist before
# any request handler can be called.
build_vector_store()

# Initialize Groq client (shared by all requests).
client = Groq(api_key=GROQ_API_KEY)
| # Pydantic models | |
class ContextInput(BaseModel):
    # Request body accepted by identify_jtbd.
    # `context` is the free text to classify; the prompt built from it
    # describes it as context taken from an email.
    context: str
class JTBDOutput(BaseModel):
    # Response body returned by identify_jtbd.
    request_id: str  # UUID4 string generated for each request
    job_name: str
    department: str  # the matched job's business_function
    source: str  # identify_jtbd always sets this to "email"
    intent_type: str
    confidence: float  # Parsed from the LLM reply; 0.5 when the fallback to the top retrieved item is used
| # FastAPI app | |
app = FastAPI(title="JTBD Identifier AI Agent", version="1.0.0")

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive policy -- confirm this is intended outside of a demo.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def startup_event():
    """Log that the application has started."""
    # NOTE(review): no decorator registering this coroutine with ``app`` is
    # visible in this view -- confirm it is actually hooked to app startup.
    logger.info("Application started")
def _parse_llm_choice(raw):
    """Extract ``(job_name, department, intent_type, confidence)`` from an LLM reply.

    Raises ``json.JSONDecodeError``, ``KeyError``, ``TypeError`` or
    ``ValueError`` when the reply is not the JSON object the prompt asks for.
    """
    # The reply may wrap the object in code fences or surrounding prose;
    # parse only the outermost {...} span when one exists.
    start, end = raw.find("{"), raw.rfind("}")
    payload = raw[start:end + 1] if 0 <= start < end else raw
    parsed = json.loads(payload)

    names = (parsed["job_name"], parsed["department"], parsed["intent_type"])
    if not all(isinstance(value, str) for value in names):
        raise ValueError("job_name, department and intent_type must be strings")
    return (*names, float(parsed["confidence"]))


async def identify_jtbd(input_data: ContextInput):
    """Identify the single best-matching JTBD for the given context.

    Embeds ``input_data.context``, retrieves up to five nearest JTBD items
    from the FAISS index, asks the Groq LLM to choose among them, and falls
    back to the nearest item (confidence 0.5) when the reply cannot be parsed.

    Args:
        input_data: Request payload carrying the context text.

    Returns:
        JTBDOutput with a fresh request id, the chosen job, its department and
        intent type, ``source="email"`` and a confidence value.

    Raises:
        HTTPException: status 500 when no candidates are available or any
            step fails unexpectedly.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm it is registered with ``app``. The embedding and Groq calls below
    are not awaited, so they appear to block the event loop for their
    duration -- confirm this is acceptable for the expected load.
    """
    try:
        # Generate unique request ID
        request_id = str(uuid.uuid4())

        # Embed the input context
        context_embedding = np.array(model.encode([input_data.context])).astype('float32')

        # Retrieve top-k similar JTBDs (at most 5). k is clamped to the number
        # of stored items because FAISS pads missing results with -1, and a
        # -1 would silently index the *last* element of jtbd_items.
        k = min(5, len(jtbd_items))
        if k == 0:
            raise HTTPException(status_code=500, detail="No JTBD items are loaded")
        _distances, indices = index.search(context_embedding, k)
        top_items = [jtbd_items[i] for i in indices[0] if i >= 0]
        if not top_items:
            raise HTTPException(status_code=500, detail="No JTBD candidates were retrieved")

        # Prepare prompt for Groq
        top_descriptions = "\n\n".join([
            f"Job {i+1}: {item.name}\nDescription: {item.description}\nDepartment: {item.business_function}\nIntent: {item.intent_type}"
            for i, item in enumerate(top_items)
        ])
        prompt = f"""
You are an expert at identifying Jobs To Be Done (JTBD) from email contexts.
Given the following context from an email:
"{input_data.context}"
And these top candidate JTBDs:
{top_descriptions}
Identify the SINGLE BEST matching JTBD. Respond in JSON format only:
{{
"job_name": "exact name of the job",
"department": "exact business_function",
"intent_type": "exact intent_type",
"confidence": <float between 0.0 and 1.0, your estimated match confidence>
}}
If no good match, use the first one with confidence 0.0.
"""

        # Call Groq LLM
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-8b-8192",  # Or "mixtral-8x7b-32768" for better reasoning
            temperature=0.1,
            max_tokens=200,
        )
        # Message content can be absent; treat that as an unparseable reply
        # so the fallback below is used instead of crashing on .strip().
        response = (chat_completion.choices[0].message.content or "").strip()

        # Parse JSON response
        try:
            job_name, department, intent_type, confidence = _parse_llm_choice(response)
        except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
            logger.warning(f"Failed to parse Groq response: {response}, error: {e}")
            # Fallback to top match
            top_match = top_items[0]
            job_name = top_match.name
            department = top_match.business_function
            intent_type = top_match.intent_type
            confidence = 0.5  # Default fallback

        # Fixed source as 'email'
        source = "email"

        logger.info(f"JTBD identified for request {request_id}: {job_name} in {department}")
        return JTBDOutput(
            request_id=request_id,
            job_name=job_name,
            department=department,
            source=source,
            intent_type=intent_type,
            confidence=confidence,
        )
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it.
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error in identify_jtbd: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
| # Add this AT THE END of your app.py file | |
| from fastapi.responses import FileResponse | |
@app.get("/")
async def read_index():
    """Serve the front-end page at the root URL.

    Without the route decorator this coroutine is never registered with
    ``app`` and the root URL would not serve the page.

    Returns:
        FileResponse for ``index.html`` (path is relative to the process's
        working directory).
    """
    # This serves the HTML file when someone visits the root URL
    return FileResponse('index.html')
async def health_check():
    """Return a fixed payload reporting the service as healthy."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm which path it is registered under.
    payload = {"status": "healthy"}
    return payload
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)