File size: 4,516 Bytes
2c29579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120492
2c29579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120492
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import json
import csv
from celery import Celery

from infra.database import get_db, JobModel, DatasetModel
from infra.logger import get_logger
from infra.result_contract import normalize_results

log = get_logger(__name__)
from core.insights import generate_insights, generate_story
from core.meta_learning import save_meta_record
from core.pipeline_engine import PipelineEngine, PipelineContext
from services.training.components import (
    DataValidationComponent,
    FeatureEngineeringComponent,
    ModelSelectionComponent,
    TrainingComponent,
    EvaluationComponent
)

# ── CSV Field Size Limit ──────────────────────────────────────────────────────
# Raise the csv module's per-field cap (process-wide) so that very large cells
# do not make csv readers raise csv.Error. Presumably this protects dataset
# parsing done elsewhere in the worker — nothing in this file reads CSV itself.
csv.field_size_limit(1_000_000_000)

# One Redis instance serves as both the Celery broker and the result backend.
# Override with the REDIS_URL environment variable.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")

celery_app = Celery("automl_worker", broker=REDIS_URL, backend=REDIS_URL)

# Keep retrying the broker connection while the worker is starting up rather
# than failing immediately if Redis is not reachable yet.
celery_app.conf.broker_connection_retry_on_startup = True

def _load_profile(dataset_id):
    """Return ``(profile_data, health_metadata)`` for ``dataset_id``.

    Reads ``DatasetModel.profile_json`` and JSON-decodes it. Every failure mode
    (DB error, missing row, empty column, malformed JSON, JSON that is not an
    object) degrades to two empty dicts so training can proceed without
    profile context; failures are logged rather than silently swallowed.
    """
    try:
        with get_db() as db:
            ds = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first()
            # Read the column while the session is still open.
            raw = ds.profile_json if ds else None
    except Exception as e:
        log.warning(f"Could not load profile for dataset {dataset_id}: {e}", exc_info=True)
        return {}, {}

    if not raw:
        return {}, {}

    try:
        profile_data = json.loads(raw)
    except (TypeError, ValueError) as e:
        log.warning(f"Malformed profile_json for dataset {dataset_id}: {e}")
        return {}, {}

    if not isinstance(profile_data, dict):
        log.warning(f"profile_json for dataset {dataset_id} is not a JSON object; ignoring it")
        return {}, {}

    # A missing or null "health" entry becomes an empty dict.
    return profile_data, profile_data.get("health") or {}


def _generate_narrative(profile_data, results):
    """Return ``(insights, story)``, generating each independently.

    A failure in one generator no longer discards the other's output:
    ``insights`` falls back to ``{}`` and ``story`` to ``None`` on their own.
    """
    try:
        insights = generate_insights(profile_data, results)
    except Exception as e:
        log.warning(f"Insights generation failed: {e}", exc_info=True)
        insights = {}

    try:
        story = generate_story(profile_data, results)
    except Exception as e:
        log.warning(f"Story generation failed: {e}", exc_info=True)
        story = None

    return insights, story


def _dumps_or(value, fallback, job_id, label):
    """Serialize ``value`` to a JSON string, or ``fallback()`` if that fails.

    ``fallback`` is a zero-argument callable so the replacement value is only
    built when needed. The failure is logged so dropped data is visible.
    """
    try:
        return json.dumps(value)
    except (TypeError, ValueError) as e:
        log.warning(f"Job {job_id}: {label} is not JSON-serializable, storing fallback: {e}")
        return json.dumps(fallback())


def _persist_results(job_id, results, insights, story, reasoning):
    """Mark ``JobModel`` ``job_id`` as completed and store its outputs.

    Best effort: any DB error is logged as a warning and not re-raised, so
    the task itself still succeeds.
    """
    try:
        with get_db() as db:
            job = db.query(JobModel).filter(JobModel.id == job_id).first()
            if job is None:
                log.warning(f"Job {job_id} not found; training results were not persisted")
                return

            job.status = "completed"
            job.results_json = _dumps_or(results, dict, job_id, "results")
            job.insights_json = _dumps_or(insights, dict, job_id, "insights")
            job.story = story
            job.model_path = results.get("model_path") if isinstance(results, dict) else None
            job.reasoning_json = _dumps_or(
                reasoning, lambda: [str(r) for r in reasoning], job_id, "reasoning"
            )
            db.commit()
    except Exception as e:
        log.warning(f"Final DB write failed: {e}", exc_info=True)


@celery_app.task(bind=True, max_retries=0)
def run_training_job(
    self, job_id, dataset_id, file_path, target_column, goal, mode,
    eval_metric="Performance",
    selected_features=None,
    handle_imbalance=False,
    auto_clean=True,
    cv_folds=0,
):
    """
    Celery task: runs training in background using the Modular Component Pipeline Engine.

    Steps:
      1. Load the dataset profile (and its "health" section) from DatasetModel.
      2. Run the validation -> feature engineering -> model selection ->
         training -> evaluation component pipeline.
      3. Normalize the metrics, generate insights/story, save a meta-learning
         record (best effort).
      4. Write status "completed", results, insights, story, model_path and
         reasoning to the JobModel row (best effort).

    Args:
        job_id: id of the JobModel row updated on completion.
        dataset_id: id of the DatasetModel row whose profile_json is read.
        file_path, target_column, goal, mode: passed to PipelineContext.
        eval_metric, selected_features, handle_imbalance, auto_clean, cv_folds:
            packed into the PipelineContext ``config`` dict.

    Raises:
        Any exception from building or running the pipeline, or from
        normalize_results, propagates so Celery records the task as failed.
    """
    profile_data, health_metadata = _load_profile(dataset_id)

    ctx = PipelineContext(
        job_id=job_id,
        dataset_id=dataset_id,
        file_path=file_path,
        target_column=target_column,
        goal=goal,
        mode=mode,
        config={
            "eval_metric": eval_metric,
            "selected_features": selected_features,
            "handle_imbalance": handle_imbalance,
            "auto_clean": auto_clean,
            "cv_folds": cv_folds,
        },
    )
    ctx.health_metadata = health_metadata

    engine = PipelineEngine(
        context=ctx,
        components=[
            DataValidationComponent(),
            FeatureEngineeringComponent(),
            ModelSelectionComponent(),
            TrainingComponent(),
            EvaluationComponent(),
        ],
    )

    # NOTE(review): pipeline failures propagate, but this task does not mark
    # the JobModel row as failed — confirm PipelineEngine (or another caller)
    # does, otherwise the job keeps its previous status.
    final_ctx = engine.run()
    results = normalize_results(final_ctx.metrics or {})

    insights, story = _generate_narrative(profile_data, results)

    try:
        save_meta_record(profile_data, results or {})
    except Exception as e:
        log.warning(f"Meta-learning save skipped: {e}", exc_info=True)

    reasoning = getattr(final_ctx, "reasoning", None)
    if not isinstance(reasoning, list):
        reasoning = []

    _persist_results(job_id, results, insights, story, reasoning)