Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from agents import coordinator | |
| from google.adk.sessions import InMemorySessionService | |
| from constants import INSTITUTE_MAPPING, BRANCH_MAPPING | |
| from google.adk.tools import google_search | |
| from google.adk.runners import Runner | |
| from google.genai import types # Add this import for Content and Part | |
| from dotenv import load_dotenv | |
| import os | |
| import re | |
| import datetime | |
# Load environment variables (e.g. from a local .env file) before any config is read.
load_dotenv()
# NOTE(review): not referenced in this module's visible code; presumably consumed by the
# Google client libraries via the environment or imported elsewhere — confirm before removing.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

app = FastAPI(
    title="PreBot College Counselor API",
    description="AI-powered college counseling system with multi-agent architecture",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS origins come from CORS_ALLOW_ORIGINS (comma-separated list, e.g.
# "https://app.example.com,https://admin.example.com"). When unset or blank, every
# origin is allowed, which matches the previous hard-coded behaviour.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; set CORS_ALLOW_ORIGINS explicitly for production deployments.
_cors_allow_origins = [
    origin.strip()
    for origin in (os.getenv("CORS_ALLOW_ORIGINS") or "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One session store shared by every request, so a (user_id, session_id) pair keeps its
# conversation history across calls for the lifetime of this process.
session_service = InMemorySessionService()
class ChatRequest(BaseModel):
    """Incoming chat message: who is asking, in which conversation, and what."""

    user_id: str  # identifies the caller; scopes the stored session
    session_id: str  # client-chosen conversation id; created on first use
    question: str  # raw user text; abbreviations are expanded before the agent sees it
class ChatResponse(BaseModel):
    """Outgoing chat reply returned by the chat endpoint."""

    session_id: str  # echoes the session_id sent in the request
    answer: str  # cleaned-up agent reply shown to the user
def _expand_abbreviations(text: str, replacements) -> str:
    """Replace whole-word, case-insensitive occurrences of each key in a single regex pass.

    Args:
        text: The string to rewrite.
        replacements: Mapping of abbreviation -> full text.

    Returns:
        ``text`` with every matched key replaced. Because it is one pass, text inserted
        by a replacement is never re-scanned, so expansions cannot cascade.
    """
    # Longest keys first so the alternation prefers "iit bombay" over "bombay".
    ordered_keys = [key for key in sorted(replacements, key=len, reverse=True) if key]
    if not ordered_keys:
        # An empty alternation would match the empty string at every word boundary.
        return text

    lookup = {}
    for key in ordered_keys:
        # First (longest / earliest) key wins when two keys differ only by case.
        lookup.setdefault(key.lower(), replacements[key])

    pattern = re.compile(
        r"\b(?:" + "|".join(re.escape(key) for key in ordered_keys) + r")\b",
        flags=re.IGNORECASE,
    )
    # A callable replacement inserts the full name literally; a template string would
    # interpret backslashes and group references inside it.
    return pattern.sub(
        lambda match: lookup.get(match.group(0).lower(), match.group(0)), text
    )


def preprocess_query(query: str, institute_mapping=None, branch_mapping=None) -> str:
    """Expand institute and branch abbreviations in a user query.

    Institute keys map to a sequence whose first element is the canonical name.
    Branch keys map directly to a full name. Matching is case-insensitive and
    whole-word. Institutes are expanded first, then branches.

    Args:
        query: Raw user question.
        institute_mapping: Optional override for ``INSTITUTE_MAPPING`` (used when None).
        branch_mapping: Optional override for ``BRANCH_MAPPING`` (used when None).

    Returns:
        The query with every recognised abbreviation replaced.
    """
    if institute_mapping is None:
        institute_mapping = INSTITUTE_MAPPING
    if branch_mapping is None:
        branch_mapping = BRANCH_MAPPING

    institute_names = {key: names[0] for key, names in institute_mapping.items()}
    query = _expand_abbreviations(query, institute_names)
    return _expand_abbreviations(query, branch_mapping)
async def chat_options():
    """Reply to an OPTIONS request for the chat endpoint with permissive CORS headers.

    Returns a JSON body ``{"message": "OK"}`` that allows any origin, POST/OPTIONS
    methods and any request headers.
    """
    # NOTE(review): no route decorator is visible on this handler; presumably it is
    # registered for OPTIONS /chat elsewhere — confirm.
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    ok_body = {"message": "OK"}
    return JSONResponse(content=ok_body, headers=cors_headers)
# ADK application name under which chat sessions are stored and run.
_SESSION_APP_NAME = "coordinator_agent"

# Shown when the agent produced no text at all.
_EMPTY_REPLY_MESSAGE = (
    "Our systems are currently overloaded due to heavy usage on the free plan. "
    "Please try again in a moment."
)

# The coordinator may append a routing tag such as "[CHOICE:about_college_agent]".
_CHOICE_TAG_RE = re.compile(r"\[CHOICE:([a-zA-Z0-9_\-]+)\]")
_CHOICE_TAG_STRIP_RE = re.compile(r"\n?\[CHOICE:[a-zA-Z0-9_\-]+\]\n?")


async def _get_or_create_session(user_id: str, session_id: str):
    """Return the stored session for (user_id, session_id), creating one if missing.

    Returns the object from ``get_session`` when a session already exists. Otherwise
    returns None, after attempting ``create_session``. Lookup and creation errors are
    logged and swallowed (best effort); a missing session surfaces later as a runner
    error.
    """
    print("Checking for existing session...")
    try:
        existing = await session_service.get_session(
            app_name=_SESSION_APP_NAME, user_id=user_id, session_id=session_id
        )
    except Exception as lookup_error:  # not bare: must not swallow CancelledError
        print(f"Session lookup error: {lookup_error}")
        existing = None

    if existing:
        print("Using existing session")
        return existing

    print("Creating new session...")
    try:
        await session_service.create_session(
            app_name=_SESSION_APP_NAME, user_id=user_id, session_id=session_id
        )
    except Exception as session_error:
        print(f"Session creation error: {session_error}")
    return None


def _read_last_agent(session):
    """Return ``metadata["last_agent"]`` (or ``meta["last_agent"]``) from a dict session, else None."""
    # NOTE(review): only dict-shaped sessions are inspected. If get_session returns a
    # session object instead, last_agent is never found — confirm against the service.
    if not session or not isinstance(session, dict):
        return None
    try:
        metadata = session.get("metadata") or session.get("meta") or {}
        if isinstance(metadata, dict):
            return metadata.get("last_agent")
    except Exception as meta_err:
        print(f"Could not read session metadata: {meta_err}")
    return None


def _event_text(event) -> str:
    """Return the text carried by one runner event, or '' when it has none.

    All non-thought text parts of ``event.content`` are concatenated. ``parts`` may
    be None, which is treated as empty.
    """
    if hasattr(event, "content"):
        content = event.content
        if hasattr(content, "parts"):
            return "".join(
                part.text
                for part in (content.parts or [])
                # Skip model "thought" parts so internal reasoning never reaches the user.
                if getattr(part, "text", None) and not getattr(part, "thought", False)
            )
        return getattr(content, "text", None) or ""
    return getattr(event, "text", None) or ""


async def _collect_reply_text(events) -> str:
    """Consume the runner's async event stream and return the user-facing reply.

    Returns the stripped text of the first final-response event that has text. If no
    final response carries text, falls back to the text of all other events joined
    with spaces. Returns '' when nothing textual was produced. The generator is always
    closed in this task, even after an early return.
    """
    fallback_chunks = []
    try:
        async for event in events:
            print(f"Processing event: {event}")
            text = _event_text(event)
            is_final = (
                callable(getattr(event, "is_final_response", None))
                and event.is_final_response()
            )
            if is_final and text:
                return text.strip()
            if text:
                fallback_chunks.append(text)
    finally:
        aclose = getattr(events, "aclose", None)
        if aclose is not None:
            await aclose()
    return " ".join(fallback_chunks).strip()


def _split_choice_tag(reply_text: str):
    """Return (reply with all [CHOICE:...] tags removed, first chosen agent name or None)."""
    match = _CHOICE_TAG_RE.search(reply_text)
    if match is None:
        return reply_text, None
    return _CHOICE_TAG_STRIP_RE.sub("", reply_text).strip(), match.group(1)


async def _persist_last_agent(user_id: str, session_id: str, session, chosen_agent: str) -> None:
    """Best-effort: record ``chosen_agent`` as the session's ``last_agent``.

    Uses ``session_service.update_session`` when it exists, trying a keyword call and
    then a positional call, and awaiting the result if it is awaitable. Otherwise it
    writes into a dict-shaped session in place. Failures are logged, never raised.
    """
    import inspect

    try:
        update = getattr(session_service, "update_session", None)
        if update is not None:
            try:
                result = update(
                    app_name=_SESSION_APP_NAME,
                    user_id=user_id,
                    session_id=session_id,
                    metadata={"last_agent": chosen_agent},
                )
            except TypeError:
                # Some implementations might have a different signature.
                result = update(user_id, session_id, {"last_agent": chosen_agent})
            if inspect.isawaitable(result):
                # An async update_session does nothing unless awaited.
                await result
        elif session and isinstance(session, dict):
            session.setdefault("metadata", {})["last_agent"] = chosen_agent
    except Exception as persist_err:
        print(f"Failed to persist chosen_agent to session: {persist_err}")


def _tidy_reply(reply_text: str) -> str:
    """Strip backticks, collapse 3+ consecutive newlines to one blank line, drop lone '*'."""
    reply_text = reply_text.replace("`", "")
    # A regex collapses runs of any length; a single str.replace("\n\n\n", ...) pass
    # leaves runs of 4+ newlines partially collapsed.
    reply_text = re.sub(r"\n{3,}", "\n\n", reply_text)
    # Remove single '*' (markdown italics/bullets) while keeping '**' bold markers.
    return re.sub(r"(?<!\*)\*(?!\*)", "", reply_text)


async def chat_endpoint(req: ChatRequest):
    """Answer one chat question through the ADK coordinator agent.

    Steps:
        1. Ensure a session exists for (user_id, session_id).
        2. Expand abbreviations in the question, prefixing ``LAST_AGENT: <name>``
           when a previous routing choice is known.
        3. Run the coordinator and extract the reply text.
        4. Strip and persist any ``[CHOICE:...]`` routing tag.
        5. Clean up the markdown before returning.

    Args:
        req: The validated chat request body.

    Returns:
        ChatResponse with the request's session_id and the cleaned answer.

    Raises:
        HTTPException: 500 for any unexpected error; the detail includes its message.
    """
    try:
        print(f"Received request - User ID: {req.user_id}, Session ID: {req.session_id}")
        print(f"Question: {req.question}")

        existing_session = await _get_or_create_session(req.user_id, req.session_id)

        print("Creating runner...")
        runner = Runner(
            agent=coordinator,
            app_name=_SESSION_APP_NAME,
            session_service=session_service,  # shared store, so history persists
        )

        print("Processing query...")
        processed_query = preprocess_query(req.question)
        last_agent_name = _read_last_agent(existing_session)
        if last_agent_name:
            # Agreed prefix format that lets the coordinator honour follow-up questions.
            processed_query = f"LAST_AGENT: {last_agent_name}\n" + processed_query
        print(f"Processed query: {processed_query}")
        user_msg = types.Content(role="user", parts=[types.Part(text=processed_query)])

        print("Running agent...")
        # run_async keeps this coroutine cooperative; the sync Runner.run would block
        # the event loop (and every other request) while the agent runs.
        events = runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=user_msg,
        )
        reply_text = await _collect_reply_text(events) or _EMPTY_REPLY_MESSAGE

        reply_text, chosen_agent = _split_choice_tag(reply_text)
        if chosen_agent:
            await _persist_last_agent(req.user_id, req.session_id, existing_session, chosen_agent)

        print(f"Final reply: {reply_text}")
        return ChatResponse(session_id=req.session_id, answer=_tidy_reply(reply_text))
    except Exception as e:
        import traceback

        print(f"Error occurred: {str(e)}")
        print(f"Error type: {type(e)}")
        print(f"Full traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
async def root():
    """Describe the running service: greeting, status, version and main endpoint paths."""
    endpoint_paths = {"chat": "/chat", "docs": "/docs", "redoc": "/redoc"}
    return dict(
        message="PreBot College Counselor API is running!",
        status="healthy",
        version="1.0.0",
        endpoints=endpoint_paths,
    )
async def health_check():
    """Report liveness plus the server's current local time as an ISO 8601 string."""
    current_time = datetime.datetime.now()
    return dict(status="healthy", timestamp=current_time.isoformat())