File size: 16,409 Bytes
df31aa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
"""
Active Learning Module for Cognexa ML Service

Implements uncertainty-based active learning to identify the most informative
samples for human review, minimizing labeling cost while maximizing model improvement.

Strategies:
- Least confidence sampling: pick samples where model is least certain
- Margin sampling: smallest gap between top-2 class probabilities
- Entropy sampling: highest Shannon entropy across class probabilities
- Query-by-committee (QBC): disagreement between ensemble members
"""

from __future__ import annotations

import json
import logging
import uuid
from dataclasses import dataclass, asdict, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import numpy as np

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Data Structures
# ---------------------------------------------------------------------------

@dataclass
class UncertainSample:
    """One candidate that active learning has queued for a human label.

    The first nine fields are required and describe the sample at selection
    time. The last four default to "not reviewed yet" and are filled in when
    a reviewer submits a label.
    """

    # -- identity -----------------------------------------------------------
    sample_id: str
    user_id: str
    task_id: Optional[str]

    # -- model input and the prediction that was judged uncertain -----------
    features: Dict[str, float]
    current_prediction: Dict[str, Any]

    # -- uncertainty bookkeeping --------------------------------------------
    # Score on a 0-1 scale; a larger value means the model was less sure.
    uncertainty_score: float
    # Name of the scoring strategy:
    # 'least_confidence' | 'margin' | 'entropy' | 'qbc'
    uncertainty_method: str
    # Review urgency, from 1 (highest) down to 5 (lowest).
    query_priority: int
    # Timestamp string recorded when the sample was queued.
    created_at: str

    # -- review outcome (defaults mean "still pending") ---------------------
    reviewed: bool = False
    review_label: Optional[Any] = None
    reviewer_id: Optional[str] = None
    reviewed_at: Optional[str] = None


@dataclass
class ActiveLearningBatch:
    """The result of one selection run: the chosen samples plus run metadata."""

    batch_id: str
    # Name of the uncertainty strategy that produced this batch.
    strategy: str
    # Samples chosen for human review.
    samples: List[UncertainSample]
    # Number of candidates that were considered.
    total_pool_size: int
    # Number of samples in this batch.
    batch_size: int
    created_at: str
    # Estimated accuracy improvement if every sample in the batch gets labeled.
    expected_model_gain: float


@dataclass
class ActiveLearningStats:
    """Summary counters describing how far the review queue has progressed."""

    # -- volume ---------------------------------------------------------------
    total_queried: int
    total_reviewed: int
    total_pending: int

    # -- ratios and averages --------------------------------------------------
    review_rate: float
    avg_uncertainty: float

    # -- breakdown: sample count per uncertainty strategy name ----------------
    coverage_by_strategy: Dict[str, int]

    # -- heuristic estimate of model improvement ------------------------------
    model_improvement_estimate: float


# ---------------------------------------------------------------------------
# Uncertainty Estimators
# ---------------------------------------------------------------------------

class UncertaintyEstimator:
    """Turns class probabilities into an uncertainty score.

    All methods are static and return a plain ``float`` where a larger value
    means the model (or committee) is less sure. Empty or too-short inputs
    return the fallback value documented on each method instead of raising.
    """

    @staticmethod
    def least_confidence(probabilities: List[float]) -> float:
        """
        1 - max(p): how far the most-confident class is from certainty.

        Args:
            probabilities: class probabilities for one prediction.

        Returns:
            0.0 when the top class has probability 1; grows as the top class
            gets less likely. Returns the neutral value 0.5 for an empty list.
        """
        if not probabilities:
            return 0.5
        return float(1.0 - max(probabilities))

    @staticmethod
    def margin_confidence(probabilities: List[float]) -> float:
        """
        1 - (gap between the two highest probabilities).

        A small gap means the model can barely separate its two best guesses,
        so the returned score is high.

        Args:
            probabilities: class probabilities for one prediction.

        Returns:
            A value in [0, 1] for probabilities in [0, 1]; 1.0 when the top two
            classes tie. Returns the neutral value 0.5 when fewer than two
            probabilities are given.
        """
        if len(probabilities) < 2:
            return 0.5
        best, runner_up = sorted(probabilities, reverse=True)[:2]
        return float(1.0 - (best - runner_up))

    @staticmethod
    def entropy(probabilities: List[float]) -> float:
        """
        Shannon entropy: H = -sum(p * log2(p)).
        Normalised to [0, 1] by dividing by log2(n_classes).

        Args:
            probabilities: class probabilities for one prediction.

        Returns:
            0.0 for a fully confident prediction, 1.0 for a uniform one.
            Returns 0.5 for an empty list and 0.0 for a single class.
        """
        if not probabilities:
            return 0.5
        n_classes = len(probabilities)
        if n_classes == 1:
            return 0.0
        # Non-positive entries are skipped (p * log2(p) -> 0 as p -> 0), so no
        # epsilon is needed inside the log. Adding one made a fully confident
        # prediction such as [1.0, 0.0] come out slightly negative.
        raw = -sum(p * np.log2(p) for p in probabilities if p > 0)
        normalised = raw / np.log2(n_classes)
        # Clamp so rounding error or inputs that do not sum to 1 cannot push
        # the score outside the documented range.
        return float(min(1.0, max(0.0, normalised)))

    @staticmethod
    def query_by_committee(predictions_list: List[List[float]]) -> float:
        """
        Vote entropy: how much committee members disagree.

        Each member votes for its argmax class; the entropy of the vote
        distribution is divided by the largest entropy the committee could
        produce, log2(min(n_members, n_classes)).

        Args:
            predictions_list: list of probability vectors from each committee
                member.

        Returns:
            0.0 if all members agree (or fewer than two members are given),
            1.0 if the votes are spread as evenly as possible.
        """
        if not predictions_list or len(predictions_list) < 2:
            return 0.0
        n_members = len(predictions_list)
        # Count votes for each class (argmax)
        votes = [int(np.argmax(member)) for member in predictions_list]
        vote_counts = np.bincount(votes, minlength=len(predictions_list[0]))
        # Votes can land on at most min(members, classes) distinct classes.
        # Normalising by log2(n_members) alone meant a committee larger than
        # the class count could never score 1 (a 2/2 split of four members on
        # two classes scored only 0.5).
        max_distinct = min(n_members, len(vote_counts))
        if max_distinct < 2:
            return 0.0
        vote_share = vote_counts / n_members
        raw = -sum(v * np.log2(v) for v in vote_share if v > 0)
        return float(min(1.0, max(0.0, raw / np.log2(max_distinct))))


# ---------------------------------------------------------------------------
# Active Learning Selector
# ---------------------------------------------------------------------------

class ActiveLearningSelector:
    """
    Selects the most informative unlabeled samples from a candidate pool.

    Each candidate is scored with the configured uncertainty strategy;
    candidates scoring below ``threshold`` are dropped and the rest are
    returned most-uncertain first.
    """

    def __init__(self, strategy: str = "entropy", threshold: float = 0.65):
        """
        Args:
            strategy: 'least_confidence' | 'margin' | 'entropy' | 'qbc'.
                Any other value is scored with entropy.
            threshold: minimum uncertainty score to include in batch (0-1)
        """
        self.strategy = strategy
        self.threshold = threshold
        self.estimator = UncertaintyEstimator()

    def score_sample(
        self,
        prediction: Dict[str, Any],
        committee_predictions: Optional[List[List[float]]] = None,
    ) -> float:
        """
        Compute an uncertainty score for a single prediction.

        Args:
            prediction: dict read for ``completion_probability`` (default 0.5)
                and ``delay_probability`` (default 1 - completion probability).
            committee_predictions: per-member probability vectors; only used
                by the 'qbc' strategy.

        Returns:
            The score from the configured strategy. 'qbc' without committee
            data and unrecognised strategy names fall back to entropy.
        """
        completion_prob = float(prediction.get("completion_probability", 0.5))
        delay_prob = float(prediction.get("delay_probability", 1.0 - completion_prob))
        probs = [completion_prob, delay_prob]

        if self.strategy == "least_confidence":
            return self.estimator.least_confidence(probs)
        if self.strategy == "margin":
            return self.estimator.margin_confidence(probs)
        if self.strategy == "qbc" and committee_predictions:
            return self.estimator.query_by_committee(committee_predictions)
        # 'entropy', 'qbc' with no committee data, and unknown strategy names.
        return self.estimator.entropy(probs)

    def select_batch(
        self,
        candidate_pool: List[Dict[str, Any]],
        batch_size: int = 20,
        user_id: Optional[str] = None,
    ) -> ActiveLearningBatch:
        """
        Select the top-k uncertain samples from candidate_pool.

        Args:
            candidate_pool: list of dicts with keys:
                - task_id (str)
                - features (Dict[str, float])
                - prediction (Dict[str, Any])
                - committee_predictions (optional List[List[float]])
                - user_id (optional str)
            batch_size: maximum number of samples to include; zero or a
                negative value yields an empty batch.
            user_id: when given, every selected sample is recorded under this
                user id; otherwise each candidate's own ``user_id`` is used
                (``"unknown"`` if absent). The pool is NOT filtered by user.

        Returns:
            ActiveLearningBatch with ranked uncertain samples. Its
            ``batch_size`` is the number of samples actually selected.
        """
        scored: List[Tuple[float, Dict[str, Any]]] = []
        for candidate in candidate_pool:
            score = self.score_sample(
                candidate.get("prediction", {}),
                candidate.get("committee_predictions"),
            )
            if score >= self.threshold:
                scored.append((score, candidate))

        # Sort descending by uncertainty (stable: ties keep pool order).
        scored.sort(key=lambda pair: pair[0], reverse=True)
        # Guard against a negative size: scored[:-1] would silently return
        # "everything except the last item" rather than nothing.
        top_k = scored[: max(0, batch_size)]

        samples = []
        for score, candidate in top_k:
            # Higher uncertainty -> numerically lower (more urgent) priority,
            # clamped to the 1..5 range.
            priority = min(5, max(1, int((1.0 - score) * 5) + 1))
            samples.append(
                UncertainSample(
                    sample_id=str(uuid.uuid4()),
                    user_id=user_id or candidate.get("user_id", "unknown"),
                    task_id=candidate.get("task_id"),
                    features=candidate.get("features", {}),
                    current_prediction=candidate.get("prediction", {}),
                    uncertainty_score=round(float(score), 4),
                    uncertainty_method=self.strategy,
                    query_priority=priority,
                    # Naive UTC, kept as-is so stored timestamps stay
                    # comparable with the rest of this module.
                    created_at=datetime.utcnow().isoformat(),
                )
            )

        # Estimate model gain (heuristic: based on avg uncertainty of selected batch).
        # float() keeps numpy scalar types out of the returned dataclass.
        avg_uncertainty = (
            float(np.mean([s.uncertainty_score for s in samples])) if samples else 0.0
        )
        model_gain = avg_uncertainty * 0.05  # ~5% improvement per 1.0 of uncertainty

        return ActiveLearningBatch(
            batch_id=str(uuid.uuid4()),
            strategy=self.strategy,
            samples=samples,
            total_pool_size=len(candidate_pool),
            batch_size=len(samples),
            created_at=datetime.utcnow().isoformat(),
            expected_model_gain=round(model_gain, 4),
        )


# ---------------------------------------------------------------------------
# Active Learning Manager (persistence + orchestration)
# ---------------------------------------------------------------------------

class ActiveLearningManager:
    """
    Manages the active learning pipeline:
    - Stores uncertain sample batches
    - Tracks which have been reviewed
    - Provides stats for the dashboard

    Pending and reviewed samples live in two JSON files under ``data_dir``.
    Both lists are loaded into memory at construction and each file is
    rewritten in full whenever its list changes.
    """

    def __init__(self, data_dir: str = "data/active_learning"):
        """
        Args:
            data_dir: directory holding the JSON files; created if missing.
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.pending_file = self.data_dir / "pending_samples.json"
        self.reviewed_file = self.data_dir / "reviewed_samples.json"
        self._pending: List[UncertainSample] = self._load(self.pending_file)
        self._reviewed: List[UncertainSample] = self._load(self.reviewed_file)

    # -- Persistence ----------------------------------------------------------

    def _load(self, path: Path) -> List[UncertainSample]:
        """Read samples from ``path``; a missing or unreadable file yields []."""
        if not path.exists():
            return []
        try:
            with open(path, encoding="utf-8") as f:
                raw = json.load(f)
            return [UncertainSample(**item) for item in raw]
        except Exception as e:
            # Best effort by design: a bad file is logged and treated as empty.
            logger.warning("Could not load %s: %s", path, e)
            return []

    def _write_samples(self, path: Path, samples: List[UncertainSample]) -> None:
        """Serialise ``samples`` to ``path`` as JSON, replacing the file in one step."""
        # Write to a sibling temp file and rename it over the target. An
        # interrupted in-place write would leave truncated JSON behind, which
        # _load() treats as "no samples", so the next save would then discard
        # everything that had been stored.
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump([asdict(s) for s in samples], f, indent=2)
        tmp_path.replace(path)

    def _save_pending(self):
        """Persist the in-memory pending list."""
        self._write_samples(self.pending_file, self._pending)

    def _save_reviewed(self):
        """Persist the in-memory reviewed list."""
        self._write_samples(self.reviewed_file, self._reviewed)

    # -- Public API -----------------------------------------------------------

    def add_batch(self, batch: ActiveLearningBatch):
        """Persist a new batch of uncertain samples."""
        self._pending.extend(batch.samples)
        self._save_pending()
        logger.info(
            "Active learning: %d samples added (strategy=%s)", len(batch.samples), batch.strategy
        )

    def get_pending_samples(
        self,
        user_id: Optional[str] = None,
        limit: int = 20,
    ) -> List[UncertainSample]:
        """
        Retrieve pending (unreviewed) samples for review.

        Args:
            user_id: when given, only samples recorded for this user.
            limit: maximum number of samples to return.

        Returns:
            Up to ``limit`` samples, most uncertain first.
        """
        samples = [
            s for s in self._pending
            if not s.reviewed and (not user_id or s.user_id == user_id)
        ]
        # Prioritise by uncertainty score descending
        samples.sort(key=lambda s: s.uncertainty_score, reverse=True)
        return samples[:limit]

    def submit_review(
        self,
        sample_id: str,
        label: Any,
        reviewer_id: Optional[str] = None,
    ) -> bool:
        """
        Mark a sample as reviewed with a human-provided label.

        Args:
            sample_id: id of the pending sample being labeled.
            label: the reviewer's label, stored as ``review_label``.
            reviewer_id: optional id of the reviewer.

        Returns:
            True if a matching unreviewed pending sample was found and moved
            to the reviewed list, False otherwise.
        """
        for sample in self._pending:
            if sample.sample_id != sample_id or sample.reviewed:
                continue
            sample.reviewed = True
            sample.review_label = label
            sample.reviewer_id = reviewer_id
            sample.reviewed_at = datetime.utcnow().isoformat()
            self._reviewed.append(sample)
            self._pending = [s for s in self._pending if s.sample_id != sample_id]
            # Write the reviewed file first: if the process dies between the
            # two writes the sample is still listed as pending on disk, which
            # is recoverable, instead of vanishing from both files.
            self._save_reviewed()
            self._save_pending()
            logger.info("Sample %s reviewed with label=%s", sample_id, label)
            return True
        return False

    def get_reviewed_samples(
        self,
        since_hours: int = 168,  # default: last 7 days
        limit: int = 200,
    ) -> List[UncertainSample]:
        """
        Retrieve recently reviewed samples (used to trigger retraining).

        Args:
            since_hours: only samples reviewed within this many hours.
            limit: maximum number of samples to return.

        Returns:
            Up to ``limit`` matching samples in the order they were reviewed
            (oldest first). ``reviewed_at`` is compared as a naive UTC
            timestamp, matching what ``submit_review`` stores.
        """
        cutoff = datetime.utcnow() - timedelta(hours=since_hours)
        results = [
            s for s in self._reviewed
            if s.reviewed_at and datetime.fromisoformat(s.reviewed_at) > cutoff
        ]
        return results[:limit]

    def get_training_data(self) -> List[Dict[str, Any]]:
        """
        Export reviewed samples as training records.

        Returns:
            One flat dict per reviewed sample that has a label: the sample's
            features plus ``label``, ``task_id``, ``user_id`` and
            ``reviewed_at`` (these four win over same-named feature keys).
        """
        return [
            {
                **s.features,
                "label": s.review_label,
                "task_id": s.task_id,
                "user_id": s.user_id,
                "reviewed_at": s.reviewed_at,
            }
            for s in self._reviewed
            if s.review_label is not None
        ]

    def get_stats(self) -> ActiveLearningStats:
        """Aggregate stats for the active learning dashboard."""
        all_samples = self._pending + self._reviewed
        reviewed = [s for s in all_samples if s.reviewed]
        pending = [s for s in self._pending if not s.reviewed]

        strategy_counts: Dict[str, int] = {}
        for s in all_samples:
            strategy_counts[s.uncertainty_method] = (
                strategy_counts.get(s.uncertainty_method, 0) + 1
            )

        # Average uncertainty is taken over the still-pending samples only.
        avg_unc = (
            float(np.mean([s.uncertainty_score for s in pending])) if pending else 0.0
        )

        # Rough model gain estimate: each reviewed sample reduces uncertainty by ~0.1%,
        # capped at 20%.
        model_gain = min(0.20, len(reviewed) * 0.001)

        return ActiveLearningStats(
            total_queried=len(all_samples),
            total_reviewed=len(reviewed),
            total_pending=len(pending),
            review_rate=len(reviewed) / max(1, len(all_samples)),
            avg_uncertainty=round(avg_unc, 4),
            coverage_by_strategy=strategy_counts,
            model_improvement_estimate=round(model_gain, 4),
        )

    def should_retrain(self, min_new_samples: int = 50) -> bool:
        """
        Return True if enough new reviewed samples warrant model retraining.

        Args:
            min_new_samples: number of samples reviewed in the last 24 hours
                required to trigger retraining.
        """
        # Lift the default limit of get_reviewed_samples(): with its cap of 200
        # the count could never reach a min_new_samples above 200.
        new_samples = self.get_reviewed_samples(
            since_hours=24, limit=len(self._reviewed)
        )
        return len(new_samples) >= min_new_samples

    def export_for_retraining(self) -> Dict[str, Any]:
        """
        Export all reviewed data ready for model retraining.

        Returns:
            Dict with the training ``records``, their ``count``, an
            ``exported_at`` timestamp and ``ready_for_training`` (True once at
            least 10 records exist).
        """
        training_data = self.get_training_data()
        return {
            "records": training_data,
            "count": len(training_data),
            "exported_at": datetime.utcnow().isoformat(),
            "ready_for_training": len(training_data) >= 10,
        }


# ---------------------------------------------------------------------------
# Convenience factory (singleton)
# ---------------------------------------------------------------------------

# Module-level cache for the shared manager; stays None until first requested.
_manager_instance: Optional[ActiveLearningManager] = None


def get_active_learning_manager() -> ActiveLearningManager:
    """Return the shared ActiveLearningManager, creating it with defaults on first use."""
    global _manager_instance
    if _manager_instance is not None:
        return _manager_instance
    _manager_instance = ActiveLearningManager()
    return _manager_instance


def run_active_learning_query(
    candidate_pool: List[Dict[str, Any]],
    strategy: str = "entropy",
    batch_size: int = 20,
    user_id: Optional[str] = None,
    threshold: float = 0.55,
) -> Dict[str, Any]:
    """
    One-call entrypoint for the REST layer.

    Scores ``candidate_pool`` with the requested strategy, keeps the most
    uncertain candidates at or above ``threshold``, stores them through the
    shared manager, and returns the batch as a plain dict.

    Args:
        candidate_pool: prediction candidates to score.
        strategy: uncertainty strategy name passed to the selector.
        batch_size: maximum number of samples to select.
        user_id: optional user id passed through to the selector.
        threshold: minimum uncertainty score for a candidate to be selected.

    Returns:
        Dict with batch metadata and the selected samples as dicts.
    """
    batch = ActiveLearningSelector(strategy=strategy, threshold=threshold).select_batch(
        candidate_pool, batch_size=batch_size, user_id=user_id
    )

    # Persist before building the response.
    get_active_learning_manager().add_batch(batch)

    serialised_samples = [asdict(sample) for sample in batch.samples]
    response: Dict[str, Any] = {
        "batch_id": batch.batch_id,
        "strategy": batch.strategy,
        "samples_selected": batch.batch_size,
        "pool_size": batch.total_pool_size,
        "expected_model_gain": batch.expected_model_gain,
        "samples": serialised_samples,
        "created_at": batch.created_at,
    }
    return response