File size: 3,724 Bytes
327bfe3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f817a0
327bfe3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
core.py — LLM factory, PII-safe logging, thread-safe SSE event bus.
"""
import json
import os
import queue
import threading
from pathlib import Path
from typing import Any, Dict, Optional

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

ROOT_DIR = Path(__file__).parent           # directory containing this module
LOG_DIR  = ROOT_DIR / "logs"               # agent-step logs are appended here
LOG_DIR.mkdir(exist_ok=True)               # idempotent: created on first import

# Load .env so OPENAI_API_KEY / AISA_DETERMINISTIC etc. are in os.environ.
load_dotenv()

# ── Thread-safe SSE event bus ─────────────────────────────────────────────────
# Each Flask request registers its own SSE queue, but the parallel worker
# threads spawned by phase1/phase4 must push events into the queue of the
# request that spawned them.  A plain threading.local cannot express that
# inheritance, so the queue is stored in a module-level dict keyed by the
# "root" (request) thread id, with a second dict mapping child threads back
# to their root.  Both dicts are guarded by `_registry_lock`.

_queue_registry: Dict[int, queue.Queue] = {}   # root thread id → SSE queue
_registry_lock  = threading.Lock()             # guards both registry dicts
_thread_to_root: Dict[int, int] = {}   # child thread → root request thread


def set_event_queue(q: queue.Queue) -> None:
    """Register *q* as the SSE queue for the current (request) thread.

    Called once per Flask request, from the request's own thread; that
    thread becomes its own "root" in the child→root mapping.
    """
    root = threading.get_ident()
    with _registry_lock:
        _thread_to_root[root] = root  # a root resolves to itself
        _queue_registry[root] = q


def register_child_thread(root_tid: int) -> None:
    """Attach the current worker thread to root thread *root_tid*.

    Parallel workers call this first so that get_event_queue() resolves
    to the SSE queue registered by the originating request thread.
    """
    with _registry_lock:
        _thread_to_root[threading.get_ident()] = root_tid


def get_event_queue() -> Optional[queue.Queue]:
    """Return the SSE queue for this thread's root request, or None."""
    me = threading.get_ident()
    with _registry_lock:
        # A thread that never registered falls back to itself as root.
        return _queue_registry.get(_thread_to_root.get(me, me))


def clear_event_queue() -> None:
    """Unregister this root thread's queue and drop its child mappings."""
    root = threading.get_ident()
    with _registry_lock:
        _queue_registry.pop(root, None)
        # Rebuild the child→root map without entries pointing at this root
        # (mutate in place so other modules holding a reference still work).
        survivors = {c: r for c, r in _thread_to_root.items() if r != root}
        _thread_to_root.clear()
        _thread_to_root.update(survivors)


def push_event(event_type: str, data: Dict[str, Any]) -> None:
    """Publish one SSE event onto this thread's root queue; no-op if none."""
    target = get_event_queue()
    if target is None:
        return
    target.put({"type": event_type, "data": data})


# ── App config ────────────────────────────────────────────────────────────────
def is_deterministic() -> bool:
    """Return True when deterministic mode is enabled via AISA_DETERMINISTIC.

    Accepts the common truthy spellings ("1", "true", "yes", "on",
    case-insensitive, surrounding whitespace ignored); anything else —
    including an unset variable — is False.  Backward compatible with the
    previous behavior: "1" still enables, "0"/unset still disable.
    """
    return os.getenv("AISA_DETERMINISTIC", "0").strip().lower() in {"1", "true", "yes", "on"}


def get_llm(model: str = "gpt-4o-mini") -> ChatOpenAI:
    """Build a ChatOpenAI client for *model*.

    Temperature is 0.0 in deterministic mode, 0.4 otherwise.
    """
    temp = 0.0 if is_deterministic() else 0.4
    return ChatOpenAI(model=model, temperature=temp)


# ── PII-safe logging ──────────────────────────────────────────────────────────
# Keys whose values may carry user text or secrets — dropped before logging.
_SENSITIVE_KEYS = {
    "api_key", "headers", "text", "summary", "description",
    "query", "user_text",
}
_log_lock = threading.Lock()  # serializes appends to the shared log file


def log_step(agent: str, event: str, payload: Dict[str, Any]) -> None:
    """Record one agent step to disk and the SSE stream, PII-scrubbed.

    Values under sensitive keys are omitted entirely; long string values
    are truncated to 200 characters.  The disk write is best-effort:
    OSError is swallowed so logging can never break a request.
    """
    scrubbed = {
        k: (v[:200] + "…[truncated]" if isinstance(v, str) and len(v) > 200 else v)
        for k, v in payload.items()
        if k not in _SENSITIVE_KEYS
    }

    record = {"agent": agent, "event": event, "payload": scrubbed}
    line = json.dumps(record, ensure_ascii=False) + "\n"

    try:
        # Lock first, then open: one writer appends to the file at a time.
        with _log_lock, (LOG_DIR / "agent_steps.log").open("a", encoding="utf-8") as f:
            f.write(line)
    except OSError:
        pass  # best-effort disk logging

    # Mirror the scrubbed record to the SSE stream for live UIs.
    push_event("agent_step", {"agent": agent, "event": event, "payload": scrubbed})