File size: 9,048 Bytes
edc9558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ded8838
 
edc9558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a698773
edc9558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""
Airflow execution bridge for AutoMLOps.

Triggers a real Airflow DAG run, then watches Airflow's metadata DB for
task-state changes and mirrors them into the same ``pipeline_executions``
dict that the existing ``/api/pipeline/status/<exec_id>`` endpoint reads.

The frontend never needs to know Airflow is running — it polls the same
Flask status endpoint it always did.
"""
from __future__ import annotations
import uuid, time, threading, logging
from datetime import datetime

from pipelines.dag_engine import pipeline_executions, _lock

logger = logging.getLogger(__name__)

# Maps Airflow task states β†’ the three states the frontend understands
_AF_STATE: dict[str | None, str] = {
    None:              "pending",
    "queued":          "pending",
    "scheduled":       "pending",
    "deferred":        "pending",
    "running":         "running",
    "success":         "success",
    "skipped":         "success",
    "failed":          "failed",
    "upstream_failed": "failed",
    "removed":         "failed",
}


def _fe_state(af: str | None) -> str:
    """Translate an Airflow task state into the frontend status string.

    States that have no entry in ``_AF_STATE`` are reported as ``"pending"``.
    """
    try:
        return _AF_STATE[af]
    except KeyError:
        return "pending"


# ── watcher thread ────────────────────────────────────────────────────────────

def _apply_task_transition(exec_st: dict, tid: str, ti, fe_st: str, name: str, now: str) -> None:
    """Write one task's new frontend status into ``exec_st``.

    Called by ``_watch`` while it holds ``_lock``.  ``ti`` is the Airflow
    task-instance row for ``tid`` or ``None`` when no row was found; ``name``
    is the display name used in log lines and ``now`` the ``HH:MM:SS`` stamp
    of the current poll.
    """
    task_st = exec_st["task_states"][tid]

    if fe_st == "running":
        task_st["status"] = "running"
        task_st["started_at"] = (
            ti.start_date.isoformat() if ti and ti.start_date else None
        )
        exec_st["logs"].append(f"[{now}] ▶  {name}")

    elif fe_st == "success":
        dur = round(ti.duration, 1) if ti and ti.duration else 0
        task_st["status"] = "success"
        task_st["result"] = f"Completed in {dur}s"
        task_st["finished_at"] = (
            ti.end_date.isoformat() if ti and ti.end_date else None
        )
        exec_st["logs"].append(f"[{now}] ✔  {name} — {dur}s")

    elif fe_st == "failed":
        task_st["status"] = "failed"
        task_st["error"] = "Task failed in Airflow"
        task_st["finished_at"] = (
            ti.end_date.isoformat() if ti and ti.end_date else None
        )
        exec_st["logs"].append(f"[{now}] ✖  {name} — failed")

    else:
        # The task dropped back to a waiting state (e.g. it was queued again).
        # Without this the mirrored status would stay on its previous value
        # ("running") even though Airflow no longer reports it as running.
        task_st["status"] = "pending"


def _watch(
    exec_id: str,
    dag_id: str,
    run_id: str,
    task_ids: list[str],
    task_names: dict[str, str],
    *,
    max_polls: int = 900,
    poll_interval: float = 1.0,
) -> None:
    """
    Poll the Airflow metadata DB and mirror the run into ``pipeline_executions``.

    Every ``poll_interval`` seconds the DAG run and its task instances are
    read, each task whose frontend status changed is updated in
    ``pipeline_executions[exec_id]`` (status, timestamps, log line), and the
    overall ``progress`` percentage is recomputed from the number of tasks in
    a terminal status.

    The function returns when:

    * the DAG run state is ``success`` (execution marked ``completed``),
    * the DAG run state is ``failed`` / ``upstream_failed`` (marked ``failed``),
    * ``exec_id`` is no longer present in ``pipeline_executions``,
    * Airflow cannot be imported (marked ``failed``), or
    * ``max_polls`` polls have elapsed (marked ``failed`` with a timeout
      error).  The defaults give roughly 15 minutes of polling.

    Args:
        exec_id: Key of the execution entry in ``pipeline_executions``.
        dag_id: Airflow DAG id to query.
        run_id: Airflow run id to query.
        task_ids: Task ids to mirror; each must have an entry in the
            execution's ``task_states``.
        task_names: Optional display names per task id (falls back to the id).
        max_polls: Maximum number of polling iterations before timing out.
        poll_interval: Seconds to sleep before each poll.
    """
    try:
        from airflow.models import DagRun, TaskInstance
        from airflow.utils.session import create_session
    except ImportError:
        logger.error("Airflow is not installed — watcher thread cannot run")
        with _lock:
            if exec_id in pipeline_executions:
                pipeline_executions[exec_id]["status"] = "failed"
                pipeline_executions[exec_id]["error"]  = "Airflow not installed"
        return

    # Last frontend status pushed per task, so each transition is logged once.
    seen_states: dict[str, str] = {tid: "pending" for tid in task_ids}
    total = len(task_ids) or 1   # avoid division by zero when there are no tasks

    for _attempt in range(max_polls):
        time.sleep(poll_interval)
        try:
            with create_session() as session:
                dag_run = session.query(DagRun).filter(
                    DagRun.dag_id == dag_id,
                    DagRun.run_id == run_id,
                ).first()

                if dag_run is None:
                    continue          # run row is not in the DB yet

                tis = {
                    ti.task_id: ti
                    for ti in session.query(TaskInstance).filter(
                        TaskInstance.dag_id == dag_id,
                        TaskInstance.run_id == run_id,
                    ).all()
                }

            now      = datetime.utcnow().strftime("%H:%M:%S")
            done_cnt = 0

            with _lock:
                exec_st = pipeline_executions.get(exec_id)
                if exec_st is None:
                    return            # execution entry was removed; nothing to mirror

                for tid in task_ids:
                    ti    = tis.get(tid)
                    fe_st = _fe_state(ti.state if ti else None)

                    # Terminal tasks count towards progress on every poll,
                    # whether or not their status changed this time.
                    if fe_st in ("success", "failed"):
                        done_cnt += 1

                    if fe_st == seen_states[tid]:
                        continue

                    seen_states[tid] = fe_st
                    _apply_task_transition(
                        exec_st, tid, ti, fe_st, task_names.get(tid, tid), now
                    )

                exec_st["progress"] = int(100 * done_cnt / total)
                exec_st["status"]   = "running"

            # Check terminal state of the whole DAG run
            dag_state = str(dag_run.state)
            if dag_state == "success":
                with _lock:
                    if exec_id in pipeline_executions:
                        pipeline_executions[exec_id]["status"]       = "completed"
                        pipeline_executions[exec_id]["progress"]     = 100
                        pipeline_executions[exec_id]["completed_at"] = datetime.utcnow().isoformat()
                        pipeline_executions[exec_id]["logs"].append(
                            f"[{now}] ✔  DAG '{dag_id}' completed successfully"
                        )
                return

            elif dag_state in ("failed", "upstream_failed"):
                with _lock:
                    if exec_id in pipeline_executions:
                        pipeline_executions[exec_id]["status"] = "failed"
                        pipeline_executions[exec_id]["error"]  = "DAG run failed in Airflow"
                        pipeline_executions[exec_id]["logs"].append(
                            f"[{now}] ✖  DAG '{dag_id}' failed"
                        )
                return

        except Exception as exc:
            # Best effort: a failed poll is logged and retried on the next tick.
            logger.warning("[watcher] poll error: %s", exc)

    # Timed out
    with _lock:
        if exec_id in pipeline_executions:
            pipeline_executions[exec_id]["status"] = "failed"
            pipeline_executions[exec_id]["error"]  = "Execution watcher timed out"


# ── public API ────────────────────────────────────────────────────────────────

def trigger_pipeline(pipeline_id: str, context: dict | None = None, dag=None) -> str:
    """
    Trigger an Airflow DAG run and return an exec_id compatible with the
    existing pipeline_executions / status endpoint contract.

    The DAG run is triggered first; only if that succeeds is an execution
    entry registered in ``pipeline_executions`` and a daemon watcher thread
    started to keep that entry in sync with Airflow.

    Args:
        pipeline_id: Pipeline id, used unchanged as the Airflow DAG id.
        context: Optional run configuration passed to Airflow as ``conf``
            (an empty dict is sent when omitted or empty).
        dag: Optional DAG object from pipeline_defs.py, used for task
            metadata (``dag.tasks`` mapping and ``dag.name``).  Without it
            the execution has no per-task states.

    Returns:
        The 8-character execution id under which the run is tracked.

    Raises:
        ImportError: If Airflow is not installed.
        Exception: Whatever Airflow's ``trigger_dag`` raises is propagated.
    """
    from airflow.api.common.trigger_dag import trigger_dag as af_trigger

    exec_id = str(uuid.uuid4())[:8]
    ts      = datetime.utcnow().strftime("%Y%m%dT%H%M%S")
    # The timestamp only has one-second resolution, so the exec_id suffix
    # keeps run ids distinct when the same DAG is triggered twice within a
    # second (a duplicate run_id would be rejected by Airflow).
    run_id  = f"automlops__{ts}__{exec_id}"

    dag_id = pipeline_id   # our pipeline IDs match Airflow DAG IDs exactly

    # Fire the Airflow DAG run
    af_trigger(dag_id=dag_id, run_id=run_id, conf=context or {}, replace_microseconds=False)

    # Collect task metadata from the pipeline_defs DAG object
    task_ids   = list(dag.tasks.keys()) if dag else []
    task_names = {tid: dag.tasks[tid].name for tid in task_ids} if dag else {}

    # Initialise exec state (same schema as dag_engine.execute_dag)
    task_states = {
        tid: {"status": "pending", "started_at": None,
              "finished_at": None, "result": None, "error": None}
        for tid in task_ids
    }
    with _lock:
        pipeline_executions[exec_id] = {
            "exec_id":     exec_id,
            "dag_id":      dag_id,
            "run_id":      run_id,
            "dag_name":    dag.name if dag else dag_id,
            "status":      "queued",
            "progress":    0,
            "task_states": task_states,
            "logs": [],
            "created_at":  datetime.utcnow().isoformat(),
        }

    # Start the watcher thread (daemon: it must not keep the process alive)
    threading.Thread(
        target=_watch,
        args=(exec_id, dag_id, run_id, task_ids, task_names),
        daemon=True,
    ).start()

    return exec_id


def is_available() -> bool:
    """Return True if Airflow is installed and its metadata DB accepts a connection.

    This is a best-effort probe: any failure (Airflow missing, misconfigured,
    or the database refusing the connection) is reported as ``False`` rather
    than raised.
    """
    try:
        from airflow.utils.session import create_session

        with create_session() as session:
            # Merely opening a session does not contact the database;
            # requesting a connection forces a real connect so that an
            # unreachable DB fails here instead of being reported as available.
            session.connection()
        return True
    except Exception:
        return False