File size: 9,593 Bytes
d2cd3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10d93ef
 
d2cd3e3
 
 
 
 
10d93ef
 
 
 
 
 
 
 
 
 
 
d2cd3e3
10d93ef
 
 
 
d2cd3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f9f426
d2cd3e3
 
 
10d93ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2cd3e3
 
10d93ef
 
 
d2cd3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from agents import coordinator
from google.adk.sessions import InMemorySessionService
from constants import INSTITUTE_MAPPING, BRANCH_MAPPING
from google.adk.tools import google_search
from google.adk.runners import Runner
from google.genai import types  # Add this import for Content and Part
from dotenv import load_dotenv
import os
import re
import datetime
 
# Load environment variables (e.g. from a local .env file) before anything
# below reads them with os.getenv.
load_dotenv()
# NOTE(review): GOOGLE_API_KEY is not referenced by any code visible in this
# module; presumably the google-genai / ADK client reads the key from the
# environment on its own — confirm before relying on or removing this line.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")


# FastAPI application. Interactive API docs are served at /docs and /redoc.
app = FastAPI(
    title="PreBot College Counselor API",
    description="AI-powered college counseling system with multi-agent architecture",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Enable CORS for all origins (adjust for production)
# NOTE(review): a wildcard origin combined with allow_credentials=True is not
# allowed by the CORS specification for credentialed requests; verify how
# CORSMiddleware resolves this (it may echo the caller's Origin, letting any
# site send credentialed requests) and pin allow_origins to the real
# front-end origin(s) before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Use a shared session service instance
# A single instance is reused by every /chat request and every Runner built
# there, so a session created on one request can be found on the next.
# Presumably (per the class name) sessions live only in this process's memory:
# lost on restart and not shared between worker processes — confirm.
session_service = InMemorySessionService()

class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    # Caller identity; passed as user_id when looking up / creating the session
    # and when running the agent.
    user_id: str
    # Conversation id; an existing session with this id is reused, otherwise
    # a new one is created.
    session_id: str
    # The user's message; abbreviations are expanded by preprocess_query
    # before it is sent to the agent.
    question: str

class ChatResponse(BaseModel):
    """Response body for POST /chat."""
    # Echo of the request's session_id.
    session_id: str
    # The agent's reply after post-processing (any [CHOICE:...] routing tag,
    # backticks and lone asterisks removed).
    answer: str

def preprocess_query(
    query: str,
    institute_mapping=None,
    branch_mapping=None,
) -> str:
    """Expand institute and branch abbreviations in *query* to their full names.

    Each institute key is replaced by the first item of its mapped sequence,
    and each branch key by its mapped string. Keys match case-insensitively
    and only as whole words (not inside a longer word).

    All keys are matched in a single left-to-right pass, longest key first:
      * a longer key ("IIT B") wins over a shorter key it starts with ("IIT"),
        for branches as well as institutes;
      * text that was just inserted as a replacement is never re-scanned, so
        an expansion cannot itself be expanded again by another key.
    If the same key (ignoring case) is in both mappings, the institute
    expansion is used.

    Args:
        query: Raw user question.
        institute_mapping: Optional override for ``INSTITUTE_MAPPING``
            (key -> sequence whose first item is the replacement text).
        branch_mapping: Optional override for ``BRANCH_MAPPING``
            (key -> replacement text).

    Returns:
        The query with abbreviations expanded; unchanged if nothing matches.
    """
    if institute_mapping is None:
        institute_mapping = INSTITUTE_MAPPING
    if branch_mapping is None:
        branch_mapping = BRANCH_MAPPING

    # Lower-cased key -> replacement text. Institutes are registered first so
    # they take precedence on a collision. Empty keys are skipped because they
    # would match at every position.
    replacements = {}
    for key, names in institute_mapping.items():
        if key:
            replacements.setdefault(key.lower(), names[0])
    for key, full_name in branch_mapping.items():
        if key:
            replacements.setdefault(key.lower(), full_name)

    if not query or not replacements:
        return query

    # Regex alternation takes the first alternative that matches, so listing
    # keys longest-first selects the longest matching key at each position.
    alternation = "|".join(
        re.escape(key) for key in sorted(replacements, key=len, reverse=True)
    )
    # Lookarounds give whole-word matching even for keys that begin or end
    # with punctuation, where \b would behave unexpectedly.
    pattern = re.compile(rf"(?<!\w)(?:{alternation})(?!\w)", re.IGNORECASE)

    # A function replacement inserts the text literally; a replacement
    # *string* would interpret backslashes and group references.
    return pattern.sub(
        lambda match: replacements.get(match.group(0).lower(), match.group(0)),
        query,
    )

@app.options("/chat")
async def chat_options():
    """Answer an OPTIONS request on /chat with permissive CORS headers.

    NOTE(review): the CORSMiddleware registered on the app may answer CORS
    preflight requests before they reach this route — confirm whether this
    handler is still needed.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    body = {"message": "OK"}
    return JSONResponse(content=body, headers=cors_headers)

# App name under which sessions are stored; session lookup, session creation
# and the Runner must all agree on it.
_APP_NAME = "coordinator_agent"

# Reply used when the agent run produced no usable text.
_OVERLOADED_REPLY = (
    "Our systems are currently overloaded due to heavy usage on the free plan. "
    "Please try again in a moment."
)

# Routing tag the coordinator may append, e.g. "[CHOICE:about_college_agent]".
_CHOICE_TAG_RE = re.compile(r"\[CHOICE:([a-zA-Z0-9_\-]+)\]")
# The same tag plus at most one surrounding newline on each side, for removal.
_CHOICE_TAG_STRIP_RE = re.compile(r"\n?\[CHOICE:[a-zA-Z0-9_\-]+\]\n?")


async def _ensure_session(user_id: str, session_id: str):
    """Return the stored session for (user_id, session_id), creating it if missing.

    Returns the existing session object, or None when a new session had to be
    created. Lookup and creation errors are logged rather than raised; any
    real problem then surfaces when the agent is run against the session.
    """
    print("Checking for existing session...")
    try:
        # The session service methods are coroutines and must be awaited.
        session = await session_service.get_session(
            app_name=_APP_NAME, user_id=user_id, session_id=session_id
        )
    except Exception as lookup_error:
        # Narrowed from a bare ``except:`` so task cancellation is not swallowed.
        print(f"Session lookup error: {lookup_error}")
        session = None

    if session:
        print("Using existing session")
        return session

    print("Creating new session...")
    try:
        await session_service.create_session(
            app_name=_APP_NAME, user_id=user_id, session_id=session_id
        )
    except Exception as session_error:
        print(f"Session creation error: {session_error}")
    return None


def _read_last_agent(session):
    """Return ``last_agent`` from a dict-shaped session's metadata, or None.

    NOTE(review): only dict sessions are inspected. If the session service
    returns an object rather than a dict, this always yields None and the
    LAST_AGENT follow-up hint is never sent — confirm the intended storage.
    """
    try:
        if session and isinstance(session, dict):
            metadata = session.get("metadata") or session.get("meta") or {}
            if isinstance(metadata, dict):
                return metadata.get("last_agent")
    except Exception as meta_err:
        print(f"Could not read session metadata: {meta_err}")
    return None


def _event_texts(event) -> list:
    """Return the non-empty text fragments carried by one agent event.

    Reads ``event.content.parts`` (tolerating ``parts`` being None), else
    ``event.content.text``, else ``event.text``. Parts flagged as model
    "thoughts" are skipped so internal reasoning never reaches the user.
    """
    content = getattr(event, "content", None)
    if content is not None:
        parts = getattr(content, "parts", None)
        if parts is not None:
            return [
                part.text
                for part in parts
                if getattr(part, "text", None) and not getattr(part, "thought", False)
            ]
        text = getattr(content, "text", None)
        return [text] if text else []
    text = getattr(event, "text", None)
    return [text] if text else []


def _is_final_response(event) -> bool:
    """True when the event exposes ``is_final_response()`` and it returns truthy."""
    check = getattr(event, "is_final_response", None)
    return bool(callable(check) and check())


async def _collect_reply(runner, user_id: str, session_id: str, new_message) -> str:
    """Run the agent for one user turn and return the text to show the user.

    Uses the text of the first final-response event that carries any text.
    If none arrives, falls back to all text seen in earlier events
    (space-joined), and to _OVERLOADED_REPLY when there was no text at all.
    """
    fragments = []
    # run_async streams events on the server's own event loop. The sync
    # Runner.run would block that loop for the entire agent run.
    events = runner.run_async(
        user_id=user_id, session_id=session_id, new_message=new_message
    )
    try:
        async for event in events:
            print(f"Processing event: {event}")
            texts = _event_texts(event)
            if texts and _is_final_response(event):
                # Join every text part; keeping only the first truncated
                # multi-part answers.
                return "".join(texts).strip() or _OVERLOADED_REPLY
            fragments.extend(texts)
    finally:
        # Returning mid-iteration leaves the async generator suspended; close
        # it explicitly so its cleanup runs now rather than at GC time.
        await events.aclose()
    return " ".join(fragments).strip() or _OVERLOADED_REPLY


def _split_choice_tag(text: str):
    """Split a trailing ``[CHOICE:<agent>]`` routing tag off the reply.

    Returns ``(text_without_tags, agent_name)``, or ``(text, None)`` if no
    tag is present.
    """
    match = _CHOICE_TAG_RE.search(text)
    if not match:
        return text, None
    return _CHOICE_TAG_STRIP_RE.sub("", text).strip(), match.group(1)


async def _persist_last_agent(user_id: str, session_id: str, session, agent_name: str) -> None:
    """Best-effort: record *agent_name* as the session's ``last_agent``.

    Uses ``session_service.update_session`` when it exists (awaiting the
    result if it is awaitable); otherwise writes into a dict-shaped session's
    ``metadata``. Failures are logged, never raised.

    NOTE(review): if the service has no ``update_session`` and sessions are not
    dicts, nothing is persisted — confirm how last_agent should be stored.
    """
    import inspect

    try:
        if hasattr(session_service, "update_session"):
            try:
                result = session_service.update_session(
                    app_name=_APP_NAME,
                    user_id=user_id,
                    session_id=session_id,
                    metadata={"last_agent": agent_name},
                )
            except TypeError:
                # Some implementations might have a different signature
                result = session_service.update_session(
                    user_id, session_id, {"last_agent": agent_name}
                )
            # An async implementation returns a coroutine that must be awaited,
            # otherwise the update silently never happens.
            if inspect.isawaitable(result):
                await result
        elif session and isinstance(session, dict):
            session.setdefault("metadata", {})["last_agent"] = agent_name
    except Exception as persist_err:
        print(f"Failed to persist chosen_agent to session: {persist_err}")


def _tidy_reply(text: str) -> str:
    """Remove Markdown markers from the reply text.

    Drops backticks, collapses runs of three or more newlines to a single
    blank line, and removes lone ``*`` characters while keeping ``**``.
    """
    text = text.replace("`", "")
    # A single str.replace("\n\n\n", "\n\n") only shortened longer runs by one.
    text = re.sub(r"\n{3,}", "\n\n", text)
    return re.sub(r"(?<!\*)\*(?!\*)", "", text)


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(req: ChatRequest):
    """Answer one user question through the coordinator agent.

    Reuses or creates the session identified by ``req.session_id``, expands
    abbreviations in the question, prefixes a ``LAST_AGENT: <name>`` line when
    one is recorded for the session, runs the agent, strips any
    ``[CHOICE:...]`` tag (recording the chosen agent), and returns the tidied
    reply.

    Raises:
        HTTPException: status 500 for any unexpected error.
    """
    try:
        print(f"Received request - User ID: {req.user_id}, Session ID: {req.session_id}")
        print(f"Question: {req.question}")

        existing_session = await _ensure_session(req.user_id, req.session_id)

        # Use the shared session service for the Runner
        print("Creating runner...")
        runner = Runner(
            agent=coordinator,
            app_name=_APP_NAME,
            session_service=session_service,
        )

        print("Processing query...")
        processed_query = preprocess_query(req.question)
        # Tell the coordinator which agent handled the previous turn so it can
        # route follow-up questions consistently.
        last_agent_name = _read_last_agent(existing_session)
        if last_agent_name:
            processed_query = f"LAST_AGENT: {last_agent_name}\n" + processed_query
        print(f"Processed query: {processed_query}")

        user_msg = types.Content(role="user", parts=[types.Part(text=processed_query)])

        print("Running agent...")
        reply_text = await _collect_reply(runner, req.user_id, req.session_id, user_msg)

        reply_text, chosen_agent = _split_choice_tag(reply_text)
        if chosen_agent:
            await _persist_last_agent(req.user_id, req.session_id, existing_session, chosen_agent)

        print(f"Final reply: {reply_text}")
        return ChatResponse(session_id=req.session_id, answer=_tidy_reply(reply_text))
    except Exception as e:
        import traceback

        print(f"Error occurred: {str(e)}")
        print(f"Error type: {type(e)}")
        print(f"Full traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

# Root endpoint: static service status plus a map of the main routes.
@app.get("/")
async def root():
    """Return a static status payload naming the API, its version and main routes."""
    routes = {
        "chat": "/chat",
        "docs": "/docs",
        "redoc": "/redoc",
    }
    payload = {
        "message": "PreBot College Counselor API is running!",
        "status": "healthy",
        "version": "1.0.0",
    }
    payload["endpoints"] = routes
    return payload

@app.get("/health")
async def health_check():
    """Report that the API process is up.

    Returns:
        ``{"status": "healthy", "timestamp": <ISO-8601 string>}`` where the
        timestamp is timezone-aware UTC (suffix ``+00:00``), so it can be
        interpreted unambiguously regardless of the server's local zone.
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    return {"status": "healthy", "timestamp": now_utc.isoformat()}