# preBot / fapi.py
# Author: wearevenom
# Last commit: "Update fapi.py" (6f9f426, verified)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from agents import coordinator
from google.adk.sessions import InMemorySessionService
from constants import INSTITUTE_MAPPING, BRANCH_MAPPING
from google.adk.tools import google_search
from google.adk.runners import Runner
from google.genai import types # Add this import for Content and Part
from dotenv import load_dotenv
import os
import re
import datetime
# Copy variables from a local .env file (when one exists) into os.environ.
# NOTE(review): `agents` and `constants` are imported above this call, so any
# environment reads they perform at import time will not see .env values —
# confirm that ordering is intended.
load_dotenv()
# Not referenced elsewhere in the visible part of this file; presumably read
# implicitly by the Google client libraries from the environment — verify.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
# Service metadata surfaced in the generated OpenAPI UIs at /docs and /redoc.
_API_METADATA = {
    "title": "PreBot College Counselor API",
    "description": "AI-powered college counseling system with multi-agent architecture",
    "version": "1.0.0",
    "docs_url": "/docs",
    "redoc_url": "/redoc",
}
app = FastAPI(**_API_METADATA)
# CORS policy. Allowed origins come from CORS_ALLOW_ORIGINS as a comma-separated
# list (e.g. "https://app.example.com,https://admin.example.com"). When the
# variable is unset, any origin is allowed ("*"), which matches the previous
# hard-coded behaviour. Lock it down in production, because a wildcard combined
# with allow_credentials=True is very permissive. An explicitly empty value
# allows no cross-origin callers.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# A single session store for the whole process. Every request, and the Runner
# built for it, reads and writes through this same instance. Being in-memory,
# sessions presumably do not survive a process restart.
session_service: InMemorySessionService = InMemorySessionService()
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    user_id: str  # caller-supplied user identifier
    session_id: str  # conversation id; /chat creates the session if it is unknown
    question: str  # free-text question; abbreviations are expanded server-side
class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    session_id: str  # echoes the session_id from the request
    answer: str  # cleaned-up reply text produced by the coordinator agent
def _replace_whole_words(text: str, key: str, replacement: str) -> str:
    """Replace every case-insensitive, whole-word occurrence of *key* in *text*.

    Word edges are expressed as "no word character before/after" lookarounds
    instead of word-boundary anchors. Keys that start or end with punctuation
    (e.g. "C++") therefore still match. The replacement is inserted literally,
    so backslashes in it are never treated as group references.
    """
    pattern = re.compile(rf"(?<!\w){re.escape(key)}(?!\w)", re.IGNORECASE)
    return pattern.sub(lambda _match: replacement, text)


def preprocess_query(query: str, institute_mapping=None, branch_mapping=None) -> str:
    """Expand institute and branch abbreviations in *query* to their full names.

    Args:
        query: The raw user question.
        institute_mapping: Maps abbreviation -> sequence whose first item is
            the full institute name. Defaults to ``INSTITUTE_MAPPING``.
        branch_mapping: Maps abbreviation -> full branch name. Defaults to
            ``BRANCH_MAPPING``.

    Returns:
        The query with whole-word, case-insensitive matches replaced.
        Institutes are expanded first, then branches. Within each mapping,
        longer keys are tried first, so a multi-word key such as a
        hypothetical "IIT Delhi" wins over its prefix "IIT".
    """
    if institute_mapping is None:
        institute_mapping = INSTITUTE_MAPPING
    if branch_mapping is None:
        branch_mapping = BRANCH_MAPPING

    for key in sorted(institute_mapping, key=len, reverse=True):
        if key:  # an empty key would match at every position
            query = _replace_whole_words(query, key, institute_mapping[key][0])
    # Longest-first here too, consistent with the institute pass above.
    for key in sorted(branch_mapping, key=len, reverse=True):
        if key:
            query = _replace_whole_words(query, key, branch_mapping[key])
    return query
@app.options("/chat")
async def chat_options():
    """Answer an OPTIONS request on /chat with permissive CORS headers.

    NOTE(review): CORSMiddleware presumably answers genuine preflight requests
    before this route runs, so this handler likely only serves plain OPTIONS
    calls. Verify before relying on it.
    """
    cors_headers = dict(
        [
            ("Access-Control-Allow-Origin", "*"),
            ("Access-Control-Allow-Methods", "POST, OPTIONS"),
            ("Access-Control-Allow-Headers", "*"),
        ]
    )
    body = {"message": "OK"}
    return JSONResponse(content=body, headers=cors_headers)
# ADK application name used for both the session store and the Runner. The two
# must match for the Runner to see sessions created by this endpoint.
_APP_NAME = "coordinator_agent"

# Returned to the user when the agent run yields no text at all.
_EMPTY_REPLY_FALLBACK = (
    "Our systems are currently overloaded due to heavy usage on the free plan. "
    "Please try again in a moment."
)

# Routing tag the coordinator may append to its reply, e.g. "[CHOICE:about_college_agent]".
_CHOICE_TAG_RE = re.compile(r"\[CHOICE:([a-zA-Z0-9_\-]+)\]")
# Same tag plus at most one adjacent newline on each side, for removal.
_CHOICE_TAG_STRIP_RE = re.compile(r"\n?\[CHOICE:[a-zA-Z0-9_\-]+\]\n?")

# Session-state key holding the name of the agent that answered the previous turn.
_LAST_AGENT_KEY = "last_agent"


def _read_last_agent(session):
    """Return the last-agent name recorded on *session*, or None.

    Accepts ADK Session objects (value read from ``session.state``) as well as
    plain dict sessions (value read from ``metadata``/``meta``).
    """
    if session is None:
        return None
    if isinstance(session, dict):
        metadata = session.get("metadata") or session.get("meta") or {}
        return metadata.get(_LAST_AGENT_KEY) if isinstance(metadata, dict) else None
    state = getattr(session, "state", None)
    return state.get(_LAST_AGENT_KEY) if isinstance(state, dict) else None


async def _persist_last_agent(user_id, session_id, agent_name):
    """Store *agent_name* in the session state so the next turn can read it.

    ADK applies state changes through an appended event's
    ``actions.state_delta``. Mutating a Session object returned by
    ``get_session`` is not guaranteed to be persisted.
    """
    from google.adk.events import Event, EventActions

    session = await session_service.get_session(
        app_name=_APP_NAME, user_id=user_id, session_id=session_id
    )
    if session is None:
        return
    update = Event(
        invocation_id="last_agent_update",
        author="system",
        actions=EventActions(state_delta={_LAST_AGENT_KEY: agent_name}),
    )
    await session_service.append_event(session, update)


def _event_texts(event):
    """Return the non-empty text chunks carried by one runner event."""
    content = getattr(event, "content", None)
    if content is None:
        text = getattr(event, "text", None)
        return [text] if isinstance(text, str) and text else []
    # Content.parts may be None; iterating it directly would raise TypeError.
    parts = getattr(content, "parts", None)
    if parts is None:
        text = getattr(content, "text", None)
        return [text] if isinstance(text, str) and text else []
    return [
        part.text
        for part in parts
        # Parts flagged as model "thought" are not meant for the end user.
        if getattr(part, "text", None) and not getattr(part, "thought", False)
    ]


async def _collect_reply(runner, user_id, session_id, message):
    """Run the agent to completion and return its reply text (possibly empty).

    The first final-response event that carries text wins, and all of its text
    parts are kept. If no such event arrives, the text of every earlier event
    is joined with spaces as a fallback.
    """
    final_text = None
    seen_texts = []
    # run_async is consumed with `async for`, so the event loop stays free for
    # other requests. Iterating the synchronous Runner.run generator inside
    # this async handler would block it.
    async for event in runner.run_async(
        user_id=user_id, session_id=session_id, new_message=message
    ):
        print(f"Processing event: {event}")
        if final_text is not None:
            continue  # drain remaining events without using their text
        texts = _event_texts(event)
        is_final = getattr(event, "is_final_response", None)
        if callable(is_final) and is_final() and texts:
            final_text = "".join(texts)
        else:
            seen_texts.extend(texts)
    if final_text is not None:
        return final_text.strip()
    return " ".join(seen_texts).strip()


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(req: ChatRequest):
    """Answer one chat turn for (user_id, session_id) using the coordinator agent.

    Steps:
    1. Ensure the ADK session exists.
    2. Expand abbreviations in the question.
    3. If a previous turn recorded an agent, prefix the question with
       ``LAST_AGENT: <name>``.
    4. Run the coordinator.
    5. Strip any ``[CHOICE:<agent>]`` tag from the reply, remembering that
       agent for the next turn.
    6. Lightly clean the markdown before returning.

    Raises:
        HTTPException: status 500 carrying the error text for any failure.
    """
    try:
        print(f"Received request - User ID: {req.user_id}, Session ID: {req.session_id}")
        print(f"Question: {req.question}")

        # The session-service methods are coroutines here and must be awaited.
        print("Checking for existing session...")
        try:
            existing_session = await session_service.get_session(
                app_name=_APP_NAME, user_id=req.user_id, session_id=req.session_id
            )
        except Exception as lookup_err:  # not bare: let CancelledError propagate
            print(f"Session lookup error: {lookup_err}")
            existing_session = None

        if not existing_session:
            print("Creating new session...")
            try:
                await session_service.create_session(
                    app_name=_APP_NAME, user_id=req.user_id, session_id=req.session_id
                )
            except Exception as session_error:
                print(f"Session creation error: {session_error}")
        else:
            print("Using existing session")

        print("Creating runner...")
        runner = Runner(
            agent=coordinator,
            app_name=_APP_NAME,
            session_service=session_service,  # the shared store created at module level
        )

        print("Processing query...")
        processed_query = preprocess_query(req.question)
        # Tell the coordinator which agent answered last, in the agreed
        # "LAST_AGENT: <name>" first-line format, so it can route follow-ups.
        last_agent_name = _read_last_agent(existing_session)
        if last_agent_name:
            processed_query = f"LAST_AGENT: {last_agent_name}\n{processed_query}"
        print(f"Processed query: {processed_query}")
        user_msg = types.Content(role="user", parts=[types.Part(text=processed_query)])

        print("Running agent...")
        reply_text = await _collect_reply(runner, req.user_id, req.session_id, user_msg)
        if not reply_text:
            reply_text = _EMPTY_REPLY_FALLBACK

        # The coordinator may append a routing tag such as "[CHOICE:about_college_agent]".
        choice_match = _CHOICE_TAG_RE.search(reply_text)
        if choice_match:
            chosen_agent = choice_match.group(1)
            reply_text = _CHOICE_TAG_STRIP_RE.sub("", reply_text).strip()
            # Best-effort: failing to remember the agent must not cost the user the answer.
            try:
                await _persist_last_agent(req.user_id, req.session_id, chosen_agent)
            except Exception as persist_err:
                print(f"Failed to persist chosen_agent to session: {persist_err}")

        print(f"Final reply: {reply_text}")
        reply_text = reply_text.replace("`", "")
        # Collapse any run of 3+ newlines (not just exactly three) to one blank line.
        reply_text = re.sub(r"\n{3,}", "\n\n", reply_text)
        # Remove lone '*' characters while keeping '**' pairs intact.
        reply_text = re.sub(r"(?<!\*)\*(?!\*)", "", reply_text)
        return ChatResponse(session_id=req.session_id, answer=reply_text)
    except Exception as e:
        import traceback

        print(f"Error occurred: {str(e)}")
        print(f"Error type: {type(e)}")
        print(f"Full traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
# Root endpoint: a service banner plus a map of the main routes.
@app.get("/")
async def root():
    """Describe the running service, its version and where its main routes live.

    The version and documentation URLs are read from ``app``, so they always
    match the values passed to ``FastAPI(...)`` instead of being duplicated here.
    """
    return {
        "message": "PreBot College Counselor API is running!",
        "status": "healthy",
        "version": app.version,
        "endpoints": {
            "chat": "/chat",
            "docs": app.docs_url,
            "redoc": app.redoc_url,
        },
    }
@app.get("/health")
async def health_check():
    """Liveness probe: report "healthy" together with the server's local time.

    The timestamp comes from ``datetime.now()`` without a timezone, so the ISO
    string carries no UTC offset.
    """
    now = datetime.datetime.now()
    return dict(status="healthy", timestamp=now.isoformat())