File size: 20,212 Bytes
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d5ca8f
6b1a94c
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
973d6a7
 
5052577
973d6a7
 
 
 
 
ab07cb1
 
 
 
 
 
973d6a7
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31d2c72
ab07cb1
 
 
 
1c9cb5b
 
 
 
 
 
 
31d2c72
ab07cb1
 
 
 
 
 
 
 
 
1c9cb5b
 
 
 
ab07cb1
 
 
 
 
 
 
 
 
1c9cb5b
 
 
 
 
 
 
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
973d6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bd6c6
 
 
 
 
 
 
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bd6c6
 
 
 
 
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
973d6a7
 
ab07cb1
 
 
b0bd6c6
 
 
ab07cb1
 
 
 
 
 
7d5ca8f
 
ab07cb1
7d5ca8f
ab07cb1
7d5ca8f
ab07cb1
7d5ca8f
ab07cb1
 
 
 
 
 
 
 
 
 
 
7d5ca8f
0f698a2
7d5ca8f
0f698a2
 
7d5ca8f
 
 
0f698a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d5ca8f
0f698a2
 
 
 
 
 
 
 
 
 
 
 
7d5ca8f
 
6b1a94c
 
 
 
 
 
 
 
7d5ca8f
 
 
 
 
 
 
 
ab07cb1
 
 
 
973d6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab07cb1
973d6a7
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d5ca8f
 
 
 
ab07cb1
 
 
 
 
 
b0bd6c6
 
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53b38bc
ab07cb1
53b38bc
 
 
 
 
 
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
"""
Agent Wrapper for Web Interface
===============================
Wraps the LangChain agent for WebSocket streaming.
"""

import os
import sys
import asyncio
import logging
from pathlib import Path
from typing import Optional, Callable, Any, List, Dict
from queue import Queue

# Add src directory to path for eurus package
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src"))

from dotenv import load_dotenv
load_dotenv()

from langchain_openai import ChatOpenAI
from langchain.agents import create_agent

# IMPORT FROM EURUS PACKAGE - SINGLE SOURCE OF TRUTH
from eurus.config import CONFIG, AGENT_SYSTEM_PROMPT
from eurus.retrieval import _arraylake_snippet
from eurus.tools.era5 import _auto_detect_query_type
from eurus.memory import get_memory, SmartConversationMemory  # Singleton for datasets, per-session for chat
from eurus.tools import get_all_tools
from eurus.tools.repl import PythonREPLTool

logger = logging.getLogger(__name__)


class AgentSession:
    """
    Manages a single agent session with streaming support.
    """

    # Available models for the selector
    AVAILABLE_MODELS = [
        {"id": "gpt-5.4", "label": "GPT-5.4", "provider": "openai"},
        {"id": "gpt-4.1", "label": "GPT-4.1", "provider": "openai"},
        {"id": "o3", "label": "o3", "provider": "openai"},
        {"id": "gemini-3.1-pro-preview", "label": "Gemini 3.1 Pro", "provider": "google"},
    ]

    def __init__(self, api_keys: Optional[Dict[str, str]] = None):
        """
        Set up the session state and immediately try to bring the agent up.

        Args:
            api_keys: Optional per-user API keys. A missing or empty mapping
                is stored as an empty dict.
        """
        # Credentials supplied by the caller and the model id from CONFIG.
        self._api_keys = api_keys if api_keys else {}
        self._current_model = CONFIG.model_name

        # Agent-side state; _initialize() fills these in.
        self._initialized = False
        self._agent = None
        self._repl_tool: Optional[PythonREPLTool] = None
        self._messages: List[Dict] = []

        # Thread-safe hand-over point for plots captured by the REPL tool.
        self._plot_queue: Queue = Queue()

        # get_memory() is the global singleton holding the dataset cache
        # (shared across sessions); the conversation memory is created per
        # session and never touches other sessions.
        self._memory = get_memory()
        self._conversation = SmartConversationMemory()

        self._initialize()

    def _initialize(self):
        """
        Build the REPL tool, the tool list, the LLM and the agent.

        API keys are resolved with user-provided values taking priority over
        environment variables. Without an OpenAI key the method logs an error
        and returns, leaving the session not ready. Any exception raised while
        the components are being built is logged and leaves ``_initialized``
        set to False.

        The agent is always built with ``CONFIG.model_name``; on success
        ``_current_model`` is updated to match.
        """
        logger.info("Initializing agent session...")

        # Resolve API keys: user-provided take priority over env vars
        openai_key = self._api_keys.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
        arraylake_key = self._api_keys.get("arraylake_api_key") or os.environ.get("ARRAYLAKE_API_KEY")
        hf_token = self._api_keys.get("hf_token") or os.environ.get("HF_TOKEN")

        if not arraylake_key:
            logger.warning("ARRAYLAKE_API_KEY not found")

        # SECURITY: Do NOT write user-provided keys to os.environ!
        # os.environ is process-global — leaks keys to other sessions on shared hosts (e.g. HF Spaces).
        # Instead, store in self and pass directly to tools that need them.
        self._resolved_keys = {
            "ARRAYLAKE_API_KEY": arraylake_key or "",
            "HF_TOKEN": hf_token or "",
        }

        if not openai_key:
            logger.error("OPENAI_API_KEY not found")
            return

        try:
            # This method also runs on reinitialize(); close the REPL left over
            # from a previous attempt so its kernel is not leaked.
            if self._repl_tool is not None:
                try:
                    self._repl_tool.close()
                except Exception as e:
                    logger.error(f"Error closing previous REPL: {e}")
                self._repl_tool = None

            # Initialize REPL tool with working directory
            logger.info("Starting Python kernel...")
            self._repl_tool = PythonREPLTool(working_dir=os.getcwd())

            # Inject session-specific keys into the REPL subprocess
            # (keeps them isolated from other sessions — no os.environ pollution)
            self._repl_tool.inject_env(self._resolved_keys)

            # Captured plots are queued and picked up later by get_pending_plots()
            def on_plot_captured(base64_data: str, filepath: str, code: str = ""):
                logger.info(f"Plot captured, adding to queue: {filepath}")
                self._plot_queue.put((base64_data, filepath, code))

            self._repl_tool.set_plot_callback(on_plot_captured)
            logger.info("Plot callback registered")

            # Get ALL tools from centralized registry (no SCIENCE_TOOLS!)
            # Pass session-specific Arraylake key for isolation
            tools = get_all_tools(
                enable_routing=True,
                enable_guide=True,
                arraylake_api_key=arraylake_key or None,
            )
            # Replace the default REPL with our configured one
            tools = [t for t in tools if t.name != "python_repl"] + [self._repl_tool]

            # Initialize LLM with resolved key
            logger.info("Connecting to LLM...")
            llm = ChatOpenAI(
                model=CONFIG.model_name,
                temperature=CONFIG.temperature,
                api_key=openai_key,
            )

            # self._memory is the object returned by get_memory() (the shared
            # dataset cache); list its datasets in the system prompt if any.
            datasets = self._memory.list_datasets()
            enhanced_prompt = AGENT_SYSTEM_PROMPT

            if datasets != "No datasets in cache.":
                enhanced_prompt += f"\n\n## CACHED DATASETS\n{datasets}"

            # Create agent
            logger.info("Creating agent...")
            self._agent = create_agent(
                model=llm,
                tools=tools,
                system_prompt=enhanced_prompt,
                debug=False
            )

            # The agent was built with CONFIG.model_name; keep the reported
            # model in sync (it may hold an earlier set_provider() choice).
            self._current_model = CONFIG.model_name

            # FRESH conversation - no old messages!
            self._messages = []

            self._initialized = True
            logger.info("Agent session initialized successfully")

        except Exception as e:
            logger.exception(f"Failed to initialize agent: {e}")
            self._initialized = False

    def is_ready(self) -> bool:
        """Check if the agent is ready."""
        return self._initialized and self._agent is not None

    def get_current_model(self) -> str:
        """Return the current model name."""
        return self._current_model

    def set_provider(self, model_id: str):
        """
        Switch the LLM model and rebuild the agent around it.

        Model ids starting with "gemini" use ChatGoogleGenerativeAI with the
        Vertex key; every other id uses ChatOpenAI with the OpenAI key. If the
        required key is missing the method logs an error and returns without
        changing anything. If building the new agent fails, the exception is
        logged and the previous agent and model id stay in place.

        On success the agent's message history (``_messages``) is reset.

        Args:
            model_id: Identifier of the model to switch to.
        """
        openai_key = self._api_keys.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
        vertex_key = self._api_keys.get("vertex_api_key") or os.environ.get("vertex_api_key")

        # Determine provider from model id
        is_gemini = model_id.startswith("gemini")

        if is_gemini and not vertex_key:
            logger.error("Cannot switch to Gemini: no vertex_api_key in .env")
            return
        if not is_gemini and not openai_key:
            logger.error("Cannot switch model: no OPENAI_API_KEY")
            return

        logger.info(f"Switching model from {self._current_model} to {model_id}")

        try:
            if is_gemini:
                from langchain_google_genai import ChatGoogleGenerativeAI
                llm = ChatGoogleGenerativeAI(
                    model=model_id,
                    temperature=CONFIG.temperature,
                    api_key=vertex_key,
                    vertexai=True,
                )
            else:
                llm = ChatOpenAI(
                    model=model_id,
                    temperature=CONFIG.temperature,
                    api_key=openai_key,
                )

            # Reuse the session-specific Arraylake key resolved during
            # initialization so the rebuilt tools stay isolated per session.
            arraylake_key = getattr(self, "_resolved_keys", {}).get("ARRAYLAKE_API_KEY")
            tools = get_all_tools(
                enable_routing=True,
                enable_guide=True,
                arraylake_api_key=arraylake_key or None,
            )
            # Replace the default REPL with this session's configured one
            tools = [t for t in tools if t.name != "python_repl"] + [self._repl_tool]

            datasets = self._memory.list_datasets()
            enhanced_prompt = AGENT_SYSTEM_PROMPT
            if datasets != "No datasets in cache.":
                enhanced_prompt += f"\n\n## CACHED DATASETS\n{datasets}"

            self._agent = create_agent(
                model=llm,
                tools=tools,
                system_prompt=enhanced_prompt,
                debug=False
            )

            # Record the new model only once the agent was actually rebuilt,
            # so a failed switch never reports a model that is not in use.
            self._current_model = model_id

            # The new agent starts from an empty message history;
            # self._conversation is not touched here.
            self._messages = []
            logger.info(f"Model switched to {model_id} successfully")
        except Exception as e:
            logger.exception(f"Failed to switch model: {e}")

    def reinitialize(self):
        """Drop the current agent and run initialization again (e.g., after a transient failure)."""
        logger.warning("Attempting agent reinitialization...")
        # Reset the readiness state first so is_ready() stays False if
        # _initialize() returns early or fails.
        self._agent, self._initialized = None, False
        self._initialize()

    def clear_messages(self):
        """Clear conversation messages."""
        self._messages = []

    def get_pending_plots(self) -> List[tuple]:
        """Get all pending plots from queue."""
        plots = []
        while not self._plot_queue.empty():
            try:
                plots.append(self._plot_queue.get_nowait())
            except Exception:
                break
        return plots

    async def process_message(
        self,
        user_message: str,
        stream_callback: Callable
    ) -> str:
        """
        Run one user turn through the agent and stream the outcome.

        Events are sent through ``stream_callback`` in this order: "status"
        updates, the reply as "chunk" events, one "plot" or "video" event per
        captured media item, then one "arraylake_snippet" event per snippet.

        Args:
            user_message: Text of the user's turn.
            stream_callback: Awaitable callable invoked as
                ``stream_callback(event_type, content, **extra)``.

        Returns:
            The full response text.

        Raises:
            RuntimeError: If the agent is still not ready after one
                reinitialization attempt.
            Exception: Any error raised while invoking the agent or streaming
                is logged and re-raised after ``_messages`` has been restored
                to its state before the agent call.
        """
        if not self.is_ready():
            # Try to reinitialize once before giving up
            logger.warning("Agent not ready, attempting reinitialization...")
            self.reinitialize()
            if not self.is_ready():
                raise RuntimeError("Agent not initialized")

        # Clear any old plots from queue
        self.get_pending_plots()

        # Add user message to history (session-local memory)
        self._conversation.add_message("user", user_message)
        self._messages.append({"role": "user", "content": user_message})

        # Snapshot the history BEFORE the first await: the except handler
        # below restores it, so it must exist even if the very first
        # stream_callback call raises.
        messages_backup = list(self._messages)

        try:
            # Send status: analyzing
            await stream_callback("status", "🔍 Analyzing your request...")
            await asyncio.sleep(0.3)

            # recursion_limit caps the agent run (kept at 20 to save tokens)
            config = {"recursion_limit": 20}

            await stream_callback("status", "🤖 Processing with AI...")

            # The agent call is blocking, so run it in the default executor.
            result = await asyncio.get_running_loop().run_in_executor(
                None,
                lambda: self._agent.invoke({"messages": self._messages}, config=config)
            )

            # Only scan NEW messages from this turn
            prev_count = len(self._messages)
            self._messages = result["messages"]
            new_messages = self._messages[prev_count:]

            # Names of the tools called this turn, first-seen order, no repeats
            tool_calls_made = []
            for msg in new_messages:
                if not (hasattr(msg, 'tool_calls') and msg.tool_calls):
                    continue
                for tc in msg.tool_calls:
                    tool_name = tc.get('name', 'unknown')
                    if tool_name not in tool_calls_made:
                        tool_calls_made.append(tool_name)

            if tool_calls_made:
                tools_str = ", ".join(tool_calls_made)
                await stream_callback("status", f"🛠️ Used tools: {tools_str}")
                await asyncio.sleep(0.5)

            arraylake_snippets = self._collect_arraylake_snippets(new_messages)
            response_text = self._extract_response_text(self._messages[-1])

            # Send status: generating response
            await stream_callback("status", "✍️ Generating response...")
            await asyncio.sleep(0.2)

            # Stream the response in chunks
            chunk_size = 50
            for i in range(0, len(response_text), chunk_size):
                chunk = response_text[i:i + chunk_size]
                await stream_callback("chunk", chunk)
                await asyncio.sleep(0.01)

            # Send any captured media (plots and videos)
            # NOTE: Only use session-specific _plot_queue, NOT shared folder scan (privacy!)
            plots = self.get_pending_plots()

            if plots:
                await stream_callback("status", f"📊 Rendering {len(plots)} visualization(s)...")
                await asyncio.sleep(0.3)

            logger.info(f"Sending {len(plots)} media items to client")
            # File extensions that are sent as "video" events, with their mimetype
            video_mimetypes = {
                'gif': "image/gif",
                'webm': "video/webm",
                'mp4': "video/mp4",
            }
            for plot_data in plots:
                base64_data, filepath = plot_data[0], plot_data[1]
                code = plot_data[2] if len(plot_data) > 2 else ""

                ext = filepath.lower().split('.')[-1] if filepath else ''
                mimetype = video_mimetypes.get(ext)
                if mimetype is not None:
                    await stream_callback("video", "", data=base64_data, path=filepath, mimetype=mimetype)
                else:
                    # Default to plot (png, jpg, etc.)
                    await stream_callback("plot", "", data=base64_data, path=filepath, code=code)

            # Send Arraylake snippets AFTER response + plots exist in DOM
            for snippet in arraylake_snippets:
                await stream_callback("arraylake_snippet", snippet)

            # Save to memory
            self._conversation.add_message("assistant", response_text)

            return response_text

        except Exception as e:
            # Restore clean message state to prevent corruption on next call
            self._messages = messages_backup
            logger.exception(f"Error processing message: {e}")
            raise

    @staticmethod
    def _collect_arraylake_snippets(new_messages) -> List[Any]:
        """
        Build one Arraylake snippet per unique successful retrieve_era5_data call.

        A call counts as failed when the first later message carrying the same
        tool_call_id has content containing one of the error keywords; a call
        with no matching later message is treated as successful. Calls are
        de-duplicated on (variable, rounded region bounds).
        """
        failure_keywords = ['error', 'failed', 'exception', 'limit',
                            'exceeded', 'rejected', 'too large']
        snippets = []
        seen_snippet_keys = set()

        for i, msg in enumerate(new_messages):
            if not (hasattr(msg, 'tool_calls') and msg.tool_calls):
                continue
            for tc in msg.tool_calls:
                if tc.get('name') != 'retrieve_era5_data':
                    continue

                # Look for the tool result that answers this call
                tc_id = tc.get('id', '')
                succeeded = True
                for later_msg in new_messages[i+1:]:
                    if (hasattr(later_msg, 'tool_call_id') and
                            later_msg.tool_call_id == tc_id):
                        content = getattr(later_msg, 'content', '') or ''
                        # str() so non-string content cannot break the scan
                        lowered = str(content).lower()
                        if any(kw in lowered for kw in failure_keywords):
                            succeeded = False
                        break

                if not succeeded:
                    continue

                args = tc.get('args', {})
                variable = args.get('variable_id', 'sst')
                start_date = args.get('start_date', '')
                end_date = args.get('end_date', '')
                min_lat = args.get('min_latitude', -90)
                max_lat = args.get('max_latitude', 90)
                min_lon = args.get('min_longitude', 0)
                max_lon = args.get('max_longitude', 360)

                # Dedup key: variable + rounded region
                dedup_key = (
                    variable,
                    round(min_lat),
                    round(max_lat),
                    round(min_lon),
                    round(max_lon),
                )
                if dedup_key in seen_snippet_keys:
                    continue
                seen_snippet_keys.add(dedup_key)

                snippets.append(_arraylake_snippet(
                    variable=variable,
                    query_type=_auto_detect_query_type(
                        start_date=start_date,
                        end_date=end_date,
                        min_lat=min_lat,
                        max_lat=max_lat,
                        min_lon=min_lon,
                        max_lon=max_lon,
                    ),
                    start_date=start_date,
                    end_date=end_date,
                    min_lat=min_lat,
                    max_lat=max_lat,
                    min_lon=min_lon,
                    max_lon=max_lon,
                ))

        return snippets

    @staticmethod
    def _extract_response_text(last_message) -> str:
        """
        Turn the final agent message into plain text.

        Handles message objects with a ``content`` attribute (string, or a
        list of content blocks as Gemini can return), plain dicts with a
        "content" key, and falls back to ``str(last_message)``.
        """
        if hasattr(last_message, 'content') and last_message.content:
            raw_content = last_message.content
            if not isinstance(raw_content, list):
                return str(raw_content)

            # Extract text from each block
            parts = []
            for block in raw_content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get('text'):
                    parts.append(block['text'])
                elif hasattr(block, 'text'):
                    parts.append(block.text)
            return "\n".join(parts) if parts else str(raw_content)

        if isinstance(last_message, dict) and last_message.get('content'):
            return str(last_message['content'])

        return str(last_message)

    def close(self):
        """Release session resources: shut down the REPL tool if one exists."""
        logger.info("Closing agent session...")
        repl = self._repl_tool
        if not repl:
            return
        try:
            repl.close()
        except Exception as e:
            # Best effort: a failing REPL shutdown is logged, not raised.
            logger.error(f"Error closing REPL: {e}")


# Registry of live sessions, one AgentSession per connection (NOT a global singleton!)
# Key: unique connection ID, Value: AgentSession
# Entries are added by create_session()/get_agent_session() and removed by close_session().
_sessions: Dict[str, AgentSession] = {}


def create_session(connection_id: str, api_keys: Optional[Dict[str, str]] = None) -> AgentSession:
    """
    Return a ready session for ``connection_id``, creating one when needed.

    A registered session that is ready is reused as-is (``api_keys`` is not
    applied to it). A registered session that is not ready is closed and
    replaced by a new one built with ``api_keys``.
    """
    previous = _sessions.get(connection_id)
    if previous is not None:
        if previous.is_ready():
            logger.info(f"Reusing existing ready session for: {connection_id}")
            return previous
        # Close broken session before replacing
        previous.close()

    session = AgentSession(api_keys=api_keys)
    _sessions[connection_id] = session
    logger.info(f"Created session for connection: {connection_id}")
    return session


def get_session(connection_id: str) -> Optional[AgentSession]:
    """Look up the session registered for ``connection_id``; None when there is none."""
    try:
        return _sessions[connection_id]
    except KeyError:
        return None


def close_session(connection_id: str):
    """Close the session registered for ``connection_id`` and drop it from the registry; no-op if unknown."""
    if connection_id not in _sessions:
        return
    session = _sessions[connection_id]
    session.close()
    del _sessions[connection_id]
    logger.info(f"Closed session for connection: {connection_id}")


# DEPRECATED: Keep for backward compatibility during migration
def get_agent_session() -> AgentSession:
    """DEPRECATED: Use create_session/get_session with connection_id instead.

    Returns the session stored under the "_default" key, creating it on first use.
    """
    logger.warning("get_agent_session() is deprecated - use create_session(connection_id)")
    default_key = "_default"
    # Lazily create the shared default session for CLI/testing
    if default_key not in _sessions:
        _sessions[default_key] = AgentSession()
    return _sessions[default_key]


def shutdown_agent_session():
    """Close every registered session and log how many there were."""
    count = len(_sessions)
    # Iterate over a snapshot: close_session() deletes entries from _sessions.
    for conn_id in tuple(_sessions):
        close_session(conn_id)
    logger.info(f"Shutdown {count} sessions")