import google.generativeai as genai
import google.api_core.exceptions
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
import os
import json
import time
from sentence_transformers import SentenceTransformer, util
import torch
import asyncio
# Removed: import asyncio.to_thread # Use this for running blocking code in async
from uuid import uuid4
# --- RAG Memory (Global for the service) ---
try:
    # Loading happens at import time; the first run may download model weights.
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    print("SentenceTransformer model loaded successfully.")
except Exception as e:
    print(f"CRITICAL: Failed to load SentenceTransformer model: {e}")
    # Downstream functions (add_to_memory / retrieve_context / worker) check
    # for None and degrade gracefully instead of crashing.
    embed_model = None
memory_store = []  # In-memory store for RAG; grows unbounded (no eviction in this file)
# -------------------------------------------
# --- Gemini Configuration ---
try:
    # BUG FIX: validate the key BEFORE calling configure(); the original
    # configured the client with a possibly-None api_key and only raised
    # afterwards, so genai was "configured" with garbage before failing.
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY environment variable not set.")
    genai.configure(api_key=api_key)
    print("Google Gemini client initialized successfully.")
except Exception as e:
    print(f"CRITICAL: Failed to initialize Gemini client: {e}")
    # Sentinel checked by the worker and the classifier before any API call.
    genai = None
# Define the exact JSON structure we want Gemini to return
class TicketResponse(BaseModel):
    """Pydantic mirror of the Gemini reply shape.

    NOTE(review): this model is not referenced elsewhere in this file — the
    raw dict TICKET_SCHEMA is what is actually passed to Gemini. Keep the two
    in sync (or validate parsed replies with this model).
    """
    decision: str = Field(description="The classification category.")
    reason: str = Field(description="A brief reason for the decision.")
    next_actions: list[str] = Field(description="A list of next actions.")
# JSON schema (Gemini's OBJECT/STRING/ARRAY type vocabulary) passed as
# response_schema so generate_content returns structured JSON. Must stay in
# sync with the keys json-parsed by the classifier.
TICKET_SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "decision": {"type": "STRING"},
        "reason": {"type": "STRING"},
        "next_actions": {
            "type": "ARRAY",
            "items": {"type": "STRING"}
        }
    },
    "required": ["decision", "reason", "next_actions"]
}
# -----------------------------
# FastAPI application; the routes and startup hook below register against it.
app = FastAPI(title="Async Ticket Service (RAG + Gemini)")
class Ticket(BaseModel):
    """Inbound ticket payload accepted by POST /async_ticket."""
    channel: str   # interpolated into the LLM prompt only
    severity: str  # interpolated into the LLM prompt only
    summary: str   # free text; also used for embedding / RAG retrieval
# In-memory queue and result store
ticket_queue = asyncio.Queue()  # (ticket_id, Ticket) pairs awaiting a worker
results_store = {}  # ticket_id -> status dict; entries are never removed in this file
# --- RAG Functions (must be sync, will be called in a thread) ---
def add_to_memory(ticket_text, response_json):
    """Embed a processed ticket and append it to the in-memory RAG store.

    Synchronous on purpose — encode() is CPU-bound — so callers dispatch it
    via asyncio.to_thread. Failures are logged and swallowed: memory is a
    best-effort enhancement, never a reason to fail a ticket.
    """
    if not embed_model:
        print("No embed model, skipping add_to_memory.")
        return
    try:
        vector = embed_model.encode(ticket_text, convert_to_tensor=True)
        entry = {
            "text": ticket_text,
            "embedding": vector,
            "response": response_json,  # full JSON string from the model
        }
        memory_store.append(entry)
        print(f"Added to async memory. Memory size: {len(memory_store)}")
    except Exception as e:
        print(f"Error adding to memory: {e}")
# --- UPDATED retrieve_context function ---
def retrieve_context(query_text, top_k=2, min_similarity=0.70, max_similarity=0.99):
    """Retrieve up to *top_k* past tickets semantically similar to the query.

    Args:
        query_text: Free-text summary of the new ticket.
        top_k: Maximum number of past cases to include (applied AFTER filtering).
        min_similarity: Lowest cosine similarity considered relevant.
        max_similarity: Upper cutoff; scores at/above it are treated as
            near-duplicates of the query and excluded.

    Returns:
        A human-readable context string for the LLM prompt, or a fallback
        message when no sufficiently similar case exists (or on error).
    """
    if not embed_model or not memory_store:
        print("No memory or embed model, returning empty context.")
        return "No relevant past cases found."
    try:
        # encode() is blocking and CPU-bound; callers run this in a thread.
        query_emb = embed_model.encode(query_text, convert_to_tensor=True)
        sims = [util.cos_sim(query_emb, item["embedding"]).item() for item in memory_store]
        print(f"Raw similarity scores for '{query_text}': {sims}")
        # Rank everything, then filter by the similarity band, then cap at
        # top_k — so only truly relevant cases are ever considered.
        ranked = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)
        relevant_indices = [
            i for i in ranked
            if min_similarity <= sims[i] < max_similarity
        ][:top_k]
        if not relevant_indices:
            # BUG FIX: the old message claimed a "90% similarity threshold"
            # while the code filtered at 0.70; report the actual cutoff.
            print(
                f"No context found above {min_similarity:.0%} similarity threshold. "
                f"Best score was: {max(sims) if sims else 'N/A'}"
            )
            return "No relevant past cases found."
        # Surface the similarity score alongside each retrieved case.
        context_parts = [
            f"Past Ticket (similarity: {sims[i]:.2f}): {memory_store[i]['text']}\n"
            f"Past Response: {memory_store[i]['response']}"
            for i in relevant_indices
        ]
        print(f"Retrieved {len(relevant_indices)} relevant context(s) for async prompt")
        return "\n\n".join(context_parts)
    except Exception as e:
        print(f"Error retrieving context: {e}")
        return "Error retrieving context."
# --- END UPDATED ---
# --- THIS IS THE NEW, BETTER PROMPT ---
def build_rag_prompt(ticket: Ticket, context: str) -> str:
    """Assemble the classification prompt sent to Gemini.

    Splices the retrieved past-case context and the new ticket's fields into
    a fixed instruction template; the model is told to answer with JSON only.
    """
    prompt = f"""
You are an expert banking support assistant. Your job is to classify a new ticket.
You must choose one of three categories:
1. AI Code Patch: Select this for technical bugs, API errors, code-related problems, or system failures.
2. Vibe Workflow: Select this for standard customer requests (e.g., "unblock my card," "payment failed," "reset password," or general banking inquiries).
3. Unknown: Select this for random, vague, or irrelevant tickets (e.g., messages like "hi", "hello", or non-descriptive/empty queries).
Use the following past cases as context if relevant:
---
{context}
---
Important Instructions:
- If the retrieval context is irrelevant or noisy, ignore it and focus only on the provided ticket information.
- Do NOT guess if any information is missing or unclear.
- If information is insufficient, respond with the category "Unknown" with a clear reason.
Now classify this new ticket. Return only the valid JSON response.
New Ticket:
Channel: {ticket.channel}
Severity: {ticket.severity}
Summary: {ticket.summary}
"""
    return prompt
# ----------------------------------------
async def classify_ticket_with_gemini_async(ticket: Ticket):
    """Classify one ticket via Gemini, using RAG context from past tickets.

    Returns a 3-tuple ``(result_data, error_detail, processing_time)``:
    on success ``error_detail`` is None; on failure ``result_data`` is an
    ``{"error": ...}`` dict, ``error_detail`` the message, and the time 0.0.
    All blocking work (embedding, the HTTP call, memory writes) runs in
    worker threads so the event loop stays responsive.
    """
    if not genai:
        print("Worker error: Gemini client not initialized.")
        return {"error": "Gemini client not initialized"}, "Gemini client not initialized", 0.0  # Return 3 values
    try:
        # 1. Retrieve context (blocking, run in thread)
        context_str = await asyncio.to_thread(retrieve_context, ticket.summary)
        # 2. Build the prompt
        prompt = build_rag_prompt(ticket, context_str)
        # 3. Call Gemini (blocking, run in thread).
        # response_schema forces a JSON reply matching TICKET_SCHEMA.
        def gemini_call():
            model = genai.GenerativeModel("gemini-2.5-flash")
            response = model.generate_content(
                prompt,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json",
                    response_schema=TICKET_SCHEMA
                )
            )
            return response.text
        # --- TIMER FIX: Start timer *just* before the API call thread ---
        # (so the reported time excludes retrieval/prompt-building above)
        start_time = time.time()
        result_json_str = await asyncio.to_thread(gemini_call)
        # --- TIMER FIX: End timer *immediately* after the API call thread ---
        processing_time = time.time() - start_time
        print(f"Gemini API processing time (async): {processing_time:.2f}s")
        # 4. Parse the JSON (raises -> caught below if the reply is malformed)
        result_data = json.loads(result_json_str)
        # 5. Add to memory *after* (blocking, run in thread)
        await asyncio.to_thread(add_to_memory, ticket.summary, result_json_str)
        # Add the *correct* processing time and context to the result
        result_data["processing_time"] = processing_time
        result_data["retrieved_context"] = context_str
        # --- FIX: Return 3 values on success ---
        return result_data, None, processing_time
    except Exception as e:
        print(f"!!! Unexpected Error in async classify_ticket (Gemini): {e}")
        # Return 3 values on error
        return {"error": str(e)}, str(e), 0.0
# Worker function
async def worker(worker_id: int):
    """Consume tickets from the shared queue forever, recording outcomes.

    Each ticket moves through results_store as queued -> processing ->
    completed/error. The outer loop never exits once started; per-ticket
    failures are captured, and unexpected loop errors back off for a second.
    """
    print(f"Worker {worker_id} starting...")
    if not genai or not embed_model:
        print(f"Worker {worker_id}: Client not initialized. Worker stopping.")
        return
    while True:
        try:
            tid, tkt = await ticket_queue.get()
            print(f"Worker {worker_id} processing ticket {tid}: {tkt.summary}")
            results_store[tid] = {"status": "processing"}
            try:
                # Classifier always yields (result, error_detail, time).
                outcome, err, _elapsed = await classify_ticket_with_gemini_async(tkt)
                if err:
                    results_store[tid] = {"status": "error", "detail": err}
                else:
                    results_store[tid] = {"status": "completed", "result": outcome}
            except Exception as exc:
                print(f"Worker {worker_id} error processing {tid}: {exc}")
                results_store[tid] = {"status": "error", "detail": str(exc)}
            finally:
                ticket_queue.task_done()
        except Exception as exc:
            print(f"Worker {worker_id} critical error: {exc}")
            await asyncio.sleep(1)  # back off so a failing get() cannot spin hot
@app.on_event("startup")
async def startup_event():
    """Launch the background worker pool (3 tasks) when the service boots."""
    print("Starting 3 workers...")
    for worker_id in range(3):
        asyncio.create_task(worker(worker_id))
# Submit ticket (non-blocking)
@app.post("/async_ticket")
async def async_ticket(ticket: Ticket):
    """Accept a ticket, enqueue it for background classification.

    Returns immediately with a ticket_id the caller can poll via
    GET /result/{ticket_id}.
    """
    ticket_id = str(uuid4())
    # FIX: record "queued" BEFORE enqueueing. The original wrote the status
    # after the put(), so a worker that picked the ticket up first could have
    # its "processing"/"completed" entry overwritten by a stale "queued".
    results_store[ticket_id] = {"status": "queued"}
    await ticket_queue.put((ticket_id, ticket))
    return {"ticket_id": ticket_id, "status": "queued"}
# Get ticket result
@app.get("/result/{ticket_id}")
async def get_result(ticket_id: str):
    """Poll the stored outcome of a previously submitted ticket.

    Unknown ids fall back to {"status": "pending"} — note this means an id
    that was never issued is indistinguishable from one still in flight.
    """
    return results_store.get(ticket_id, {"status": "pending"})