"""Async ticket-classification service: FastAPI + in-memory RAG + Gemini.

Tickets are accepted on POST /async_ticket, queued, and processed by three
background workers. Each worker retrieves similar past tickets from an
in-memory embedding store, builds a RAG prompt, calls Gemini with a strict
JSON schema, and stores the result under the ticket id for GET /result/{id}.
"""

import asyncio
import json
import os
import time
from uuid import uuid4

import google.api_core.exceptions
import google.generativeai as genai
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util

# --- RAG Memory (Global for the service) ---
try:
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    print("SentenceTransformer model loaded successfully.")
except Exception as e:
    print(f"CRITICAL: Failed to load SentenceTransformer model: {e}")
    embed_model = None  # Sentinel: RAG functions become no-ops when missing.

memory_store = []  # In-memory store for RAG: dicts of {text, embedding, response}

# Similarity window used when retrieving past cases.
# FIX: these were inline magic numbers and the log message claimed "90%".
SIM_THRESHOLD = 0.70  # Minimum cosine similarity to count as relevant.
SIM_DUPLICATE_CUTOFF = 0.99  # Scores at/above this are near-duplicates; skip.
# -------------------------------------------

# --- Gemini Configuration ---
try:
    # FIX: validate the key *before* configuring the client, so a missing key
    # fails fast here rather than as an opaque API error at request time.
    if not os.getenv("GOOGLE_API_KEY"):
        raise ValueError("GOOGLE_API_KEY environment variable not set.")
    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
    print("Google Gemini client initialized successfully.")
except Exception as e:
    print(f"CRITICAL: Failed to initialize Gemini client: {e}")
    genai = None  # Sentinel checked by the worker / classifier before use.


# Define the exact JSON structure we want Gemini to return.
class TicketResponse(BaseModel):
    decision: str = Field(description="The classification category.")
    reason: str = Field(description="A brief reason for the decision.")
    next_actions: list[str] = Field(description="A list of next actions.")


# Schema handed to Gemini's structured-output mode (mirrors TicketResponse).
TICKET_SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "decision": {"type": "STRING"},
        "reason": {"type": "STRING"},
        "next_actions": {
            "type": "ARRAY",
            "items": {"type": "STRING"},
        },
    },
    "required": ["decision", "reason", "next_actions"],
}
# -----------------------------

app = FastAPI(title="Async Ticket Service (RAG + Gemini)")


class Ticket(BaseModel):
    # Incoming ticket payload for POST /async_ticket.
    channel: str
    severity: str
    summary: str


# In-memory queue and result store (single-process only; not persisted).
ticket_queue = asyncio.Queue()
results_store = {}


# --- RAG Functions (must be sync, will be called in a thread) ---
def add_to_memory(ticket_text, response_json):
    """Embed ticket_text and append it (with the raw response JSON) to memory.

    Blocking/CPU-bound (encode()); callers run this via asyncio.to_thread.
    Best-effort: failures are logged, never raised.
    """
    if not embed_model:
        print("No embed model, skipping add_to_memory.")
        return
    try:
        # Note: encode() is a blocking CPU-bound operation.
        embedding = embed_model.encode(ticket_text, convert_to_tensor=True)
        memory_store.append({
            "text": ticket_text,
            "embedding": embedding,
            "response": response_json,  # Store the full JSON string.
        })
        print(f"Added to async memory. Memory size: {len(memory_store)}")
    except Exception as e:
        print(f"Error adding to memory: {e}")


def retrieve_context(query_text, top_k=2):
    """Return a context string of up to top_k similar past tickets.

    Filters by cosine similarity in [SIM_THRESHOLD, SIM_DUPLICATE_CUTOFF)
    *before* truncating to top_k, so only genuinely relevant cases count.
    Blocking; callers run this via asyncio.to_thread.
    """
    if not embed_model or not memory_store:
        print("No memory or embed model, returning empty context.")
        return "No relevant past cases found."
    try:
        # Encode the query (blocking, CPU-bound).
        query_emb = embed_model.encode(query_text, convert_to_tensor=True)

        # Cosine similarity of the query against every stored embedding.
        sims = [
            util.cos_sim(query_emb, item["embedding"]).item()
            for item in memory_store
        ]
        # Log the raw scores for debugging.
        print(f"Raw similarity scores for '{query_text}': {sims}")

        # Sort ALL indices by similarity, then filter, then take top_k.
        all_indices_sorted = sorted(
            range(len(sims)), key=lambda i: sims[i], reverse=True
        )
        relevant_indices = [
            i
            for i in all_indices_sorted
            # Keep relevant matches but drop near-duplicates of the query.
            if SIM_THRESHOLD <= sims[i] < SIM_DUPLICATE_CUTOFF
        ][:top_k]  # Take only top_k AFTER filtering.

        if not relevant_indices:
            # FIX: message previously claimed a "90%" threshold; report the
            # actual configured threshold instead.
            print(
                f"No context found above {SIM_THRESHOLD:.0%} similarity "
                f"threshold. Best score was: {max(sims) if sims else 'N/A'}"
            )
            return "No relevant past cases found."

        # Build context string with similarity scores for transparency.
        context_parts = []
        for i in relevant_indices:
            context_parts.append(
                f"Past Ticket (similarity: {sims[i]:.2f}): {memory_store[i]['text']}\n"
                f"Past Response: {memory_store[i]['response']}"
            )
        context = "\n\n".join(context_parts)
        print(f"Retrieved {len(relevant_indices)} relevant context(s) for async prompt")
        return context
    except Exception as e:
        print(f"Error retrieving context: {e}")
        return "Error retrieving context."


def build_rag_prompt(ticket: Ticket, context: str) -> str:
    """Assemble the classification prompt: instructions + RAG context + ticket."""
    return f"""
You are an expert banking support assistant. Your job is to classify a new ticket. You must choose one of three categories:

1. AI Code Patch: Select this for technical bugs, API errors, code-related problems, or system failures.
2. Vibe Workflow: Select this for standard customer requests (e.g., "unblock my card," "payment failed," "reset password," or general banking inquiries).
3. Unknown: Select this for random, vague, or irrelevant tickets (e.g., messages like "hi", "hello", or non-descriptive/empty queries).

Use the following past cases as context if relevant:
---
{context}
---

Important Instructions:
- If the retrieval context is irrelevant or noisy, ignore it and focus only on the provided ticket information.
- Do NOT guess if any information is missing or unclear.
- If information is insufficient, respond with the category "Unknown" with a clear reason.

Now classify this new ticket. Return only the valid JSON response.

New Ticket:
Channel: {ticket.channel}
Severity: {ticket.severity}
Summary: {ticket.summary}
"""


async def classify_ticket_with_gemini_async(ticket: Ticket):
    """Classify one ticket via RAG + Gemini.

    Always returns a 3-tuple: (result_dict, error_detail_or_None, seconds).
    All blocking work (embedding, Gemini call, memory write) runs in threads
    so the event loop stays responsive.
    """
    if not genai:
        print("Worker error: Gemini client not initialized.")
        # Return 3 values even on the early-exit path.
        return (
            {"error": "Gemini client not initialized"},
            "Gemini client not initialized",
            0.0,
        )
    try:
        # 1. Retrieve context (blocking, run in thread).
        context_str = await asyncio.to_thread(retrieve_context, ticket.summary)

        # 2. Build the prompt.
        prompt = build_rag_prompt(ticket, context_str)

        # 3. Call Gemini (blocking, run in thread).
        def gemini_call():
            model = genai.GenerativeModel("gemini-2.5-flash")
            response = model.generate_content(
                prompt,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json",
                    response_schema=TICKET_SCHEMA,
                ),
            )
            return response.text

        # Time *only* the API call, not context retrieval or memory writes.
        start_time = time.time()
        result_json_str = await asyncio.to_thread(gemini_call)
        processing_time = time.time() - start_time
        print(f"Gemini API processing time (async): {processing_time:.2f}s")

        # 4. Parse the JSON.
        result_data = json.loads(result_json_str)

        # 5. Add to memory *after* (blocking, run in thread).
        await asyncio.to_thread(add_to_memory, ticket.summary, result_json_str)

        # Attach timing and the retrieved context to the result payload.
        result_data["processing_time"] = processing_time
        result_data["retrieved_context"] = context_str

        return result_data, None, processing_time
    except Exception as e:
        print(f"!!! Unexpected Error in async classify_ticket (Gemini): {e}")
        return {"error": str(e)}, str(e), 0.0


async def worker(worker_id: int):
    """Consume tickets from ticket_queue forever, writing into results_store."""
    print(f"Worker {worker_id} starting...")
    if not genai or not embed_model:
        print(f"Worker {worker_id}: Client not initialized. Worker stopping.")
        return
    while True:
        try:
            ticket_id, ticket = await ticket_queue.get()
            print(f"Worker {worker_id} processing ticket {ticket_id}: {ticket.summary}")
            results_store[ticket_id] = {"status": "processing"}
            try:
                # Classifier always yields a 3-tuple.
                result_data, error_detail, processing_time = (
                    await classify_ticket_with_gemini_async(ticket)
                )
                if error_detail:
                    results_store[ticket_id] = {
                        "status": "error",
                        "detail": error_detail,
                    }
                else:
                    results_store[ticket_id] = {
                        "status": "completed",
                        "result": result_data,
                    }
            except Exception as e:
                print(f"Worker {worker_id} error processing {ticket_id}: {e}")
                results_store[ticket_id] = {"status": "error", "detail": str(e)}
            finally:
                ticket_queue.task_done()
        except Exception as e:
            print(f"Worker {worker_id} critical error: {e}")
            await asyncio.sleep(1)  # Prevent tight loop on critical error.


@app.on_event("startup")
async def startup_event():
    """Launch the background worker tasks when the app starts."""
    print("Starting 3 workers...")
    for i in range(3):
        asyncio.create_task(worker(i))


# Submit ticket (non-blocking).
@app.post("/async_ticket")
async def async_ticket(ticket: Ticket):
    """Queue a ticket for classification and return its id immediately."""
    ticket_id = str(uuid4())
    # FIX: record the "queued" status *before* enqueueing. Previously the
    # status was written after put(), so a fast worker could set
    # "processing"/"completed" and then be overwritten back to "queued".
    results_store[ticket_id] = {"status": "queued"}
    await ticket_queue.put((ticket_id, ticket))
    return {"ticket_id": ticket_id, "status": "queued"}


# Get ticket result.
@app.get("/result/{ticket_id}")
async def get_result(ticket_id: str):
    """Return the current status/result for a ticket id ('pending' if unknown)."""
    return results_store.get(ticket_id, {"status": "pending"})